koichi12 commited on
Commit
9e4dfa6
·
verified ·
1 Parent(s): 2ca8472

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. .venv/lib/python3.11/site-packages/transformers/__pycache__/__init__.cpython-311.pyc +3 -0
  3. .venv/lib/python3.11/site-packages/transformers/models/cpm/__init__.py +27 -0
  4. .venv/lib/python3.11/site-packages/transformers/models/cpm/__pycache__/__init__.cpython-311.pyc +0 -0
  5. .venv/lib/python3.11/site-packages/transformers/models/cpm/__pycache__/tokenization_cpm.cpython-311.pyc +0 -0
  6. .venv/lib/python3.11/site-packages/transformers/models/cpm/__pycache__/tokenization_cpm_fast.cpython-311.pyc +0 -0
  7. .venv/lib/python3.11/site-packages/transformers/models/cpm/tokenization_cpm.py +348 -0
  8. .venv/lib/python3.11/site-packages/transformers/models/cpm/tokenization_cpm_fast.py +241 -0
  9. .venv/lib/python3.11/site-packages/transformers/models/cvt/__init__.py +28 -0
  10. .venv/lib/python3.11/site-packages/transformers/models/cvt/__pycache__/__init__.cpython-311.pyc +0 -0
  11. .venv/lib/python3.11/site-packages/transformers/models/cvt/__pycache__/configuration_cvt.cpython-311.pyc +0 -0
  12. .venv/lib/python3.11/site-packages/transformers/models/cvt/__pycache__/modeling_cvt.cpython-311.pyc +0 -0
  13. .venv/lib/python3.11/site-packages/transformers/models/cvt/__pycache__/modeling_tf_cvt.cpython-311.pyc +0 -0
  14. .venv/lib/python3.11/site-packages/transformers/models/cvt/configuration_cvt.py +146 -0
  15. .venv/lib/python3.11/site-packages/transformers/models/cvt/modeling_cvt.py +725 -0
  16. .venv/lib/python3.11/site-packages/transformers/models/cvt/modeling_tf_cvt.py +1096 -0
  17. .venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/__init__.py +29 -0
  18. .venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/__pycache__/__init__.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/__pycache__/configuration_encoder_decoder.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/__pycache__/modeling_encoder_decoder.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/__pycache__/modeling_flax_encoder_decoder.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/__pycache__/modeling_tf_encoder_decoder.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/configuration_encoder_decoder.py +111 -0
  24. .venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py +687 -0
  25. .venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py +901 -0
  26. .venv/lib/python3.11/site-packages/transformers/models/mpt/__init__.py +27 -0
  27. .venv/lib/python3.11/site-packages/transformers/models/mpt/__pycache__/__init__.cpython-311.pyc +0 -0
  28. .venv/lib/python3.11/site-packages/transformers/models/mpt/__pycache__/configuration_mpt.cpython-311.pyc +0 -0
  29. .venv/lib/python3.11/site-packages/transformers/models/mpt/__pycache__/modeling_mpt.cpython-311.pyc +0 -0
  30. .venv/lib/python3.11/site-packages/transformers/models/mpt/configuration_mpt.py +233 -0
  31. .venv/lib/python3.11/site-packages/transformers/models/mpt/modeling_mpt.py +917 -0
  32. .venv/lib/python3.11/site-packages/transformers/models/olmo/__init__.py +59 -0
  33. .venv/lib/python3.11/site-packages/transformers/models/olmo/__pycache__/__init__.cpython-311.pyc +0 -0
  34. .venv/lib/python3.11/site-packages/transformers/models/olmo/__pycache__/configuration_olmo.cpython-311.pyc +0 -0
  35. .venv/lib/python3.11/site-packages/transformers/models/olmo/__pycache__/modeling_olmo.cpython-311.pyc +0 -0
  36. .venv/lib/python3.11/site-packages/transformers/models/olmo/__pycache__/modular_olmo.cpython-311.pyc +0 -0
  37. .venv/lib/python3.11/site-packages/transformers/models/olmo/configuration_olmo.py +181 -0
  38. .venv/lib/python3.11/site-packages/transformers/models/olmo/modeling_olmo.py +842 -0
  39. .venv/lib/python3.11/site-packages/transformers/models/olmo/modular_olmo.py +126 -0
  40. .venv/lib/python3.11/site-packages/transformers/models/rt_detr/__init__.py +33 -0
  41. .venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/__init__.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/configuration_rt_detr.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/configuration_rt_detr_resnet.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/image_processing_rt_detr.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/image_processing_rt_detr_fast.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/modeling_rt_detr_resnet.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/modular_rt_detr.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/transformers/models/rt_detr/configuration_rt_detr.py +364 -0
  49. .venv/lib/python3.11/site-packages/transformers/models/rt_detr/configuration_rt_detr_resnet.py +114 -0
  50. .venv/lib/python3.11/site-packages/transformers/models/rt_detr/image_processing_rt_detr.py +1102 -0
.gitattributes CHANGED
@@ -433,3 +433,7 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/nvidia/cudnn/lib/
433
  .venv/lib/python3.11/site-packages/transformers/__pycache__/modeling_tf_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
434
  .venv/lib/python3.11/site-packages/transformers/__pycache__/training_args.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
435
  .venv/lib/python3.11/site-packages/transformers/__pycache__/testing_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
433
  .venv/lib/python3.11/site-packages/transformers/__pycache__/modeling_tf_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
434
  .venv/lib/python3.11/site-packages/transformers/__pycache__/training_args.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
435
  .venv/lib/python3.11/site-packages/transformers/__pycache__/testing_utils.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
436
+ .venv/lib/python3.11/site-packages/transformers/__pycache__/__init__.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
437
+ .venv/lib/python3.11/site-packages/transformers/models/wav2vec2/__pycache__/modeling_tf_wav2vec2.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
438
+ .venv/lib/python3.11/site-packages/transformers/models/wav2vec2_conformer/__pycache__/modeling_wav2vec2_conformer.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
439
+ .venv/lib/python3.11/site-packages/transformers/models/whisper/__pycache__/modeling_whisper.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
.venv/lib/python3.11/site-packages/transformers/__pycache__/__init__.cpython-311.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db4550dc6bc1e2a167f35d80dff7e798a3afa52c7d1dc8bade4f96cfb2f778ae
3
+ size 272603
.venv/lib/python3.11/site-packages/transformers/models/cpm/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .tokenization_cpm import *
22
+ from .tokenization_cpm_fast import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
.venv/lib/python3.11/site-packages/transformers/models/cpm/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (765 Bytes). View file
 
.venv/lib/python3.11/site-packages/transformers/models/cpm/__pycache__/tokenization_cpm.cpython-311.pyc ADDED
Binary file (18.1 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/cpm/__pycache__/tokenization_cpm_fast.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/cpm/tokenization_cpm.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes."""
16
+
17
+ import os
18
+ import unicodedata
19
+ from shutil import copyfile
20
+ from typing import Any, Dict, List, Optional, Tuple
21
+
22
+ import sentencepiece as spm
23
+
24
+ from ...tokenization_utils import AddedToken, PreTrainedTokenizer
25
+ from ...utils import SPIECE_UNDERLINE, logging
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
31
+
32
+
33
+ class CpmTokenizer(PreTrainedTokenizer):
34
+ """Runs pre-tokenization with Jieba segmentation tool. It is used in CPM models."""
35
+
36
+ vocab_files_names = VOCAB_FILES_NAMES
37
+
38
+ def __init__(
39
+ self,
40
+ vocab_file,
41
+ do_lower_case=False,
42
+ remove_space=True,
43
+ keep_accents=False,
44
+ bos_token="<s>",
45
+ eos_token="</s>",
46
+ unk_token="<unk>",
47
+ sep_token="<sep>",
48
+ pad_token="<pad>",
49
+ cls_token="<cls>",
50
+ mask_token="<mask>",
51
+ additional_special_tokens=["<eop>", "<eod>"],
52
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
53
+ **kwargs,
54
+ ) -> None:
55
+ """
56
+ Construct a CPM tokenizer. Based on [Jieba](https://pypi.org/project/jieba/) and
57
+ [SentencePiece](https://github.com/google/sentencepiece).
58
+
59
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should
60
+ refer to this superclass for more information regarding those methods.
61
+
62
+ Args:
63
+ vocab_file (`str`):
64
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm extension) that
65
+ contains the vocabulary necessary to instantiate a tokenizer.
66
+ do_lower_case (`bool`, *optional*, defaults to `True`):
67
+ Whether to lowercase the input when tokenizing.
68
+ remove_space (`bool`, *optional*, defaults to `True`):
69
+ Whether to strip the text when tokenizing (removing excess spaces before and after the string).
70
+ keep_accents (`bool`, *optional*, defaults to `False`):
71
+ Whether to keep accents when tokenizing.
72
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
73
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier
74
+ token.
75
+
76
+ <Tip>
77
+
78
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
79
+ sequence. The token used is the `cls_token`.
80
+
81
+ </Tip>
82
+
83
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
84
+ The end of sequence token.
85
+
86
+ <Tip>
87
+
88
+ When building a sequence using special tokens, this is not the token that is used for the end of
89
+ sequence. The token used is the `sep_token`.
90
+
91
+ </Tip>
92
+
93
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
94
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
95
+ this token instead.
96
+ sep_token (`str`, *optional*, defaults to `"<sep>"`):
97
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
98
+ for sequence classification or for a text and a question for question answering. It is also used as the
99
+ last token of a sequence built with special tokens.
100
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
101
+ The token used for padding, for example when batching sequences of different lengths.
102
+ cls_token (`str`, *optional*, defaults to `"<cls>"`):
103
+ The classifier token which is used when doing sequence classification (classification of the whole
104
+ sequence instead of per-token classification). It is the first token of the sequence when built with
105
+ special tokens.
106
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
107
+ The token used for masking values. This is the token used when training this model with masked language
108
+ modeling. This is the token which the model will try to predict.
109
+ additional_special_tokens (`List[str]`, *optional*, defaults to `["<eop>", "<eod>"]`):
110
+ Additional special tokens used by the tokenizer.
111
+
112
+ Attributes:
113
+ sp_model (`SentencePieceProcessor`):
114
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
115
+ """
116
+ # Mask token behave like a normal word, i.e. include the space before it
117
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
118
+
119
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
120
+
121
+ self.do_lower_case = do_lower_case
122
+ self.remove_space = remove_space
123
+ self.keep_accents = keep_accents
124
+ self.vocab_file = vocab_file
125
+
126
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
127
+ self.sp_model.Load(vocab_file)
128
+
129
+ try:
130
+ import jieba
131
+ except ModuleNotFoundError as error:
132
+ raise error.__class__(
133
+ "You need to install jieba to use CpmTokenizer or CpmTokenizerFast. "
134
+ "See https://pypi.org/project/jieba/ for installation."
135
+ )
136
+ self.jieba = jieba
137
+ self.translator = str.maketrans(" \n", "\u2582\u2583")
138
+
139
+ super().__init__(
140
+ do_lower_case=do_lower_case,
141
+ remove_space=remove_space,
142
+ keep_accents=keep_accents,
143
+ bos_token=bos_token,
144
+ eos_token=eos_token,
145
+ unk_token=unk_token,
146
+ sep_token=sep_token,
147
+ pad_token=pad_token,
148
+ cls_token=cls_token,
149
+ mask_token=mask_token,
150
+ additional_special_tokens=additional_special_tokens,
151
+ sp_model_kwargs=self.sp_model_kwargs,
152
+ **kwargs,
153
+ )
154
+
155
+ self._pad_token_type_id = 3
156
+
157
+ @property
158
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.vocab_size
159
+ def vocab_size(self):
160
+ return len(self.sp_model)
161
+
162
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.get_vocab
163
+ def get_vocab(self):
164
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
165
+ vocab.update(self.added_tokens_encoder)
166
+ return vocab
167
+
168
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.__getstate__
169
+ def __getstate__(self):
170
+ state = self.__dict__.copy()
171
+ state["sp_model"] = None
172
+ return state
173
+
174
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.__setstate__
175
+ def __setstate__(self, d):
176
+ self.__dict__ = d
177
+
178
+ # for backward compatibility
179
+ if not hasattr(self, "sp_model_kwargs"):
180
+ self.sp_model_kwargs = {}
181
+
182
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
183
+ self.sp_model.Load(self.vocab_file)
184
+
185
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.preprocess_text
186
+ def preprocess_text(self, inputs):
187
+ if self.remove_space:
188
+ outputs = " ".join(inputs.strip().split())
189
+ else:
190
+ outputs = inputs
191
+ outputs = outputs.replace("``", '"').replace("''", '"')
192
+
193
+ if not self.keep_accents:
194
+ outputs = unicodedata.normalize("NFKD", outputs)
195
+ outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
196
+ if self.do_lower_case:
197
+ outputs = outputs.lower()
198
+
199
+ return outputs
200
+
201
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer._tokenize
202
+ def _tokenize(self, text: str) -> List[str]:
203
+ """Tokenize a string."""
204
+ text = self.preprocess_text(text)
205
+ pieces = self.sp_model.encode(text, out_type=str)
206
+ new_pieces = []
207
+ for piece in pieces:
208
+ if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
209
+ cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
210
+ if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
211
+ if len(cur_pieces[0]) == 1:
212
+ cur_pieces = cur_pieces[1:]
213
+ else:
214
+ cur_pieces[0] = cur_pieces[0][1:]
215
+ cur_pieces.append(piece[-1])
216
+ new_pieces.extend(cur_pieces)
217
+ else:
218
+ new_pieces.append(piece)
219
+
220
+ return new_pieces
221
+
222
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer._convert_token_to_id
223
+ def _convert_token_to_id(self, token):
224
+ """Converts a token (str) in an id using the vocab."""
225
+ return self.sp_model.PieceToId(token)
226
+
227
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer._convert_id_to_token
228
+ def _convert_id_to_token(self, index):
229
+ """Converts an index (integer) in a token (str) using the vocab."""
230
+ return self.sp_model.IdToPiece(index)
231
+
232
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.convert_tokens_to_string
233
+ def convert_tokens_to_string(self, tokens):
234
+ """Converts a sequence of tokens (strings for sub-words) in a single string."""
235
+ out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
236
+ return out_string
237
+
238
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.build_inputs_with_special_tokens
239
+ def build_inputs_with_special_tokens(
240
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
241
+ ) -> List[int]:
242
+ """
243
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
244
+ adding special tokens. An XLNet sequence has the following format:
245
+
246
+ - single sequence: `X <sep> <cls>`
247
+ - pair of sequences: `A <sep> B <sep> <cls>`
248
+
249
+ Args:
250
+ token_ids_0 (`List[int]`):
251
+ List of IDs to which the special tokens will be added.
252
+ token_ids_1 (`List[int]`, *optional*):
253
+ Optional second list of IDs for sequence pairs.
254
+
255
+ Returns:
256
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
257
+ """
258
+ sep = [self.sep_token_id]
259
+ cls = [self.cls_token_id]
260
+ if token_ids_1 is None:
261
+ return token_ids_0 + sep + cls
262
+ return token_ids_0 + sep + token_ids_1 + sep + cls
263
+
264
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.get_special_tokens_mask
265
+ def get_special_tokens_mask(
266
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
267
+ ) -> List[int]:
268
+ """
269
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
270
+ special tokens using the tokenizer `prepare_for_model` method.
271
+
272
+ Args:
273
+ token_ids_0 (`List[int]`):
274
+ List of IDs.
275
+ token_ids_1 (`List[int]`, *optional*):
276
+ Optional second list of IDs for sequence pairs.
277
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
278
+ Whether or not the token list is already formatted with special tokens for the model.
279
+
280
+ Returns:
281
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
282
+ """
283
+
284
+ if already_has_special_tokens:
285
+ return super().get_special_tokens_mask(
286
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
287
+ )
288
+
289
+ if token_ids_1 is not None:
290
+ return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1, 1]
291
+ return ([0] * len(token_ids_0)) + [1, 1]
292
+
293
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.create_token_type_ids_from_sequences
294
+ def create_token_type_ids_from_sequences(
295
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
296
+ ) -> List[int]:
297
+ """
298
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLNet
299
+ sequence pair mask has the following format:
300
+
301
+ ```
302
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
303
+ | first sequence | second sequence |
304
+ ```
305
+
306
+ If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
307
+
308
+ Args:
309
+ token_ids_0 (`List[int]`):
310
+ List of IDs.
311
+ token_ids_1 (`List[int]`, *optional*):
312
+ Optional second list of IDs for sequence pairs.
313
+
314
+ Returns:
315
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
316
+ """
317
+ sep = [self.sep_token_id]
318
+ cls_segment_id = [2]
319
+
320
+ if token_ids_1 is None:
321
+ return len(token_ids_0 + sep) * [0] + cls_segment_id
322
+ return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id
323
+
324
+ # Copied from transformers.models.xlnet.tokenization_xlnet.XLNetTokenizer.save_vocabulary
325
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
326
+ if not os.path.isdir(save_directory):
327
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
328
+ return
329
+ out_vocab_file = os.path.join(
330
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
331
+ )
332
+
333
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
334
+ copyfile(self.vocab_file, out_vocab_file)
335
+ elif not os.path.isfile(self.vocab_file):
336
+ with open(out_vocab_file, "wb") as fi:
337
+ content_spiece_model = self.sp_model.serialized_model_proto()
338
+ fi.write(content_spiece_model)
339
+
340
+ return (out_vocab_file,)
341
+
342
+ def _decode(self, *args, **kwargs):
343
+ text = super()._decode(*args, **kwargs)
344
+ text = text.replace(" ", "").replace("\u2582", " ").replace("\u2583", "\n")
345
+ return text
346
+
347
+
348
+ __all__ = ["CpmTokenizer"]
.venv/lib/python3.11/site-packages/transformers/models/cpm/tokenization_cpm_fast.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes."""
16
+
17
+ import os
18
+ from shutil import copyfile
19
+ from typing import List, Optional, Tuple
20
+
21
+ from ...tokenization_utils_fast import AddedToken, PreTrainedTokenizerFast
22
+ from ...utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"}
28
+
29
+
30
+ class CpmTokenizerFast(PreTrainedTokenizerFast):
31
+ """Runs pre-tokenization with Jieba segmentation tool. It is used in CPM models."""
32
+
33
+ def __init__(
34
+ self,
35
+ vocab_file=None,
36
+ tokenizer_file=None,
37
+ do_lower_case=False,
38
+ remove_space=True,
39
+ keep_accents=False,
40
+ bos_token="<s>",
41
+ eos_token="</s>",
42
+ unk_token="<unk>",
43
+ sep_token="<sep>",
44
+ pad_token="<pad>",
45
+ cls_token="<cls>",
46
+ mask_token="<mask>",
47
+ additional_special_tokens=["<eop>", "<eod>"],
48
+ **kwargs,
49
+ ):
50
+ """
51
+ Construct a CPM tokenizer. Based on [Jieba](https://pypi.org/project/jieba/) and
52
+ [SentencePiece](https://github.com/google/sentencepiece).
53
+
54
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should
55
+ refer to this superclass for more information regarding those methods.
56
+
57
+ Args:
58
+ vocab_file (`str`):
59
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm extension) that
60
+ contains the vocabulary necessary to instantiate a tokenizer.
61
+ do_lower_case (`bool`, *optional*, defaults to `True`):
62
+ Whether to lowercase the input when tokenizing.
63
+ remove_space (`bool`, *optional*, defaults to `True`):
64
+ Whether to strip the text when tokenizing (removing excess spaces before and after the string).
65
+ keep_accents (`bool`, *optional*, defaults to `False`):
66
+ Whether to keep accents when tokenizing.
67
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
68
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier
69
+ token.
70
+
71
+ <Tip>
72
+
73
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
74
+ sequence. The token used is the `cls_token`.
75
+
76
+ </Tip>
77
+
78
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
79
+ The end of sequence token.
80
+
81
+ <Tip>
82
+
83
+ When building a sequence using special tokens, this is not the token that is used for the end of
84
+ sequence. The token used is the `sep_token`.
85
+
86
+ </Tip>
87
+
88
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
89
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
90
+ this token instead.
91
+ sep_token (`str`, *optional*, defaults to `"<sep>"`):
92
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
93
+ for sequence classification or for a text and a question for question answering. It is also used as the
94
+ last token of a sequence built with special tokens.
95
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
96
+ The token used for padding, for example when batching sequences of different lengths.
97
+ cls_token (`str`, *optional*, defaults to `"<cls>"`):
98
+ The classifier token which is used when doing sequence classification (classification of the whole
99
+ sequence instead of per-token classification). It is the first token of the sequence when built with
100
+ special tokens.
101
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
102
+ The token used for masking values. This is the token used when training this model with masked language
103
+ modeling. This is the token which the model will try to predict.
104
+ additional_special_tokens (`List[str]`, *optional*, defaults to `["<eop>", "<eod>"]`):
105
+ Additional special tokens used by the tokenizer.
106
+
107
+ Attributes:
108
+ sp_model (`SentencePieceProcessor`):
109
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
110
+ """
111
+ # Mask token behave like a normal word, i.e. include the space before it
112
+ mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
113
+
114
+ super().__init__(
115
+ vocab_file=vocab_file,
116
+ tokenizer_file=tokenizer_file,
117
+ do_lower_case=do_lower_case,
118
+ remove_space=remove_space,
119
+ keep_accents=keep_accents,
120
+ bos_token=bos_token,
121
+ eos_token=eos_token,
122
+ unk_token=unk_token,
123
+ sep_token=sep_token,
124
+ pad_token=pad_token,
125
+ cls_token=cls_token,
126
+ mask_token=mask_token,
127
+ additional_special_tokens=additional_special_tokens,
128
+ **kwargs,
129
+ )
130
+
131
+ self._pad_token_type_id = 3
132
+ self.do_lower_case = do_lower_case
133
+ self.remove_space = remove_space
134
+ self.keep_accents = keep_accents
135
+ self.vocab_file = vocab_file
136
+
137
+ try:
138
+ import jieba
139
+ except ModuleNotFoundError as error:
140
+ raise error.__class__(
141
+ "You need to install jieba to use CpmTokenizer or CpmTokenizerFast. "
142
+ "See https://pypi.org/project/jieba/ for installation."
143
+ )
144
+ self.jieba = jieba
145
+ self.translator = str.maketrans(" \n", "\u2582\u2583")
146
+
147
+ @property
148
+ def can_save_slow_tokenizer(self) -> bool:
149
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
150
+
151
+ # Copied from transformers.models.xlnet.tokenization_xlnet_fast.XLNetTokenizerFast.build_inputs_with_special_tokens
152
+ def build_inputs_with_special_tokens(
153
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
154
+ ) -> List[int]:
155
+ """
156
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
157
+ adding special tokens. An XLNet sequence has the following format:
158
+
159
+ - single sequence: `X <sep> <cls>`
160
+ - pair of sequences: `A <sep> B <sep> <cls>`
161
+
162
+ Args:
163
+ token_ids_0 (`List[int]`):
164
+ List of IDs to which the special tokens will be added.
165
+ token_ids_1 (`List[int]`, *optional*):
166
+ Optional second list of IDs for sequence pairs.
167
+
168
+ Returns:
169
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
170
+ """
171
+ sep = [self.sep_token_id]
172
+ cls = [self.cls_token_id]
173
+ if token_ids_1 is None:
174
+ return token_ids_0 + sep + cls
175
+ return token_ids_0 + sep + token_ids_1 + sep + cls
176
+
177
+ # Copied from transformers.models.xlnet.tokenization_xlnet_fast.XLNetTokenizerFast.create_token_type_ids_from_sequences
178
+ def create_token_type_ids_from_sequences(
179
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
180
+ ) -> List[int]:
181
+ """
182
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLNet
183
+ sequence pair mask has the following format:
184
+
185
+ ```
186
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
187
+ | first sequence | second sequence |
188
+ ```
189
+
190
+ If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
191
+
192
+ Args:
193
+ token_ids_0 (`List[int]`):
194
+ List of IDs.
195
+ token_ids_1 (`List[int]`, *optional*):
196
+ Optional second list of IDs for sequence pairs.
197
+
198
+ Returns:
199
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
200
+ """
201
+ sep = [self.sep_token_id]
202
+ cls_segment_id = [2]
203
+
204
+ if token_ids_1 is None:
205
+ return len(token_ids_0 + sep) * [0] + cls_segment_id
206
+ return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id
207
+
208
+ # Copied from transformers.models.xlnet.tokenization_xlnet_fast.XLNetTokenizerFast.save_vocabulary
209
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
210
+ if not self.can_save_slow_tokenizer:
211
+ raise ValueError(
212
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
213
+ "tokenizer."
214
+ )
215
+
216
+ if not os.path.isdir(save_directory):
217
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
218
+ return
219
+ out_vocab_file = os.path.join(
220
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
221
+ )
222
+
223
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
224
+ copyfile(self.vocab_file, out_vocab_file)
225
+
226
+ return (out_vocab_file,)
227
+
228
+ def _batch_encode_plus(self, batch_text_or_text_pairs, *args, **kwargs):
229
+ batch_text_or_text_pairs = [
230
+ " ".join([x.translate(self.translator) for x in self.jieba.cut(text, cut_all=False)])
231
+ for text in batch_text_or_text_pairs
232
+ ]
233
+ return super()._batch_encode_plus(batch_text_or_text_pairs, *args, **kwargs)
234
+
235
+ def _decode(self, *args, **kwargs):
236
+ text = super()._decode(*args, **kwargs)
237
+ text = text.replace(" ", "").replace("\u2582", " ").replace("\u2583", "\n")
238
+ return text
239
+
240
+
241
+ __all__ = ["CpmTokenizerFast"]
.venv/lib/python3.11/site-packages/transformers/models/cvt/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_cvt import *
22
+ from .modeling_cvt import *
23
+ from .modeling_tf_cvt import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
.venv/lib/python3.11/site-packages/transformers/models/cvt/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (794 Bytes). View file
 
.venv/lib/python3.11/site-packages/transformers/models/cvt/__pycache__/configuration_cvt.cpython-311.pyc ADDED
Binary file (6.64 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/cvt/__pycache__/modeling_cvt.cpython-311.pyc ADDED
Binary file (39.4 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/cvt/__pycache__/modeling_tf_cvt.cpython-311.pyc ADDED
Binary file (60.6 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/cvt/configuration_cvt.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """CvT model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class CvtConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`CvtModel`]. It is used to instantiate a CvT model
27
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
28
+ defaults will yield a similar configuration to that of the CvT
29
+ [microsoft/cvt-13](https://huggingface.co/microsoft/cvt-13) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ num_channels (`int`, *optional*, defaults to 3):
36
+ The number of input channels.
37
+ patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3]`):
38
+ The kernel size of each encoder's patch embedding.
39
+ patch_stride (`List[int]`, *optional*, defaults to `[4, 2, 2]`):
40
+ The stride size of each encoder's patch embedding.
41
+ patch_padding (`List[int]`, *optional*, defaults to `[2, 1, 1]`):
42
+ The padding size of each encoder's patch embedding.
43
+ embed_dim (`List[int]`, *optional*, defaults to `[64, 192, 384]`):
44
+ Dimension of each of the encoder blocks.
45
+ num_heads (`List[int]`, *optional*, defaults to `[1, 3, 6]`):
46
+ Number of attention heads for each attention layer in each block of the Transformer encoder.
47
+ depth (`List[int]`, *optional*, defaults to `[1, 2, 10]`):
48
+ The number of layers in each encoder block.
49
+ mlp_ratios (`List[float]`, *optional*, defaults to `[4.0, 4.0, 4.0, 4.0]`):
50
+ Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the
51
+ encoder blocks.
52
+ attention_drop_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.0]`):
53
+ The dropout ratio for the attention probabilities.
54
+ drop_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.0]`):
55
+ The dropout ratio for the patch embeddings probabilities.
56
+ drop_path_rate (`List[float]`, *optional*, defaults to `[0.0, 0.0, 0.1]`):
57
+ The dropout probability for stochastic depth, used in the blocks of the Transformer encoder.
58
+ qkv_bias (`List[bool]`, *optional*, defaults to `[True, True, True]`):
59
+ The bias bool for query, key and value in attentions
60
+ cls_token (`List[bool]`, *optional*, defaults to `[False, False, True]`):
61
+ Whether or not to add a classification token to the output of each of the last 3 stages.
62
+ qkv_projection_method (`List[string]`, *optional*, defaults to ["dw_bn", "dw_bn", "dw_bn"]`):
63
+ The projection method for query, key and value Default is depth-wise convolutions with batch norm. For
64
+ Linear projection use "avg".
65
+ kernel_qkv (`List[int]`, *optional*, defaults to `[3, 3, 3]`):
66
+ The kernel size for query, key and value in attention layer
67
+ padding_kv (`List[int]`, *optional*, defaults to `[1, 1, 1]`):
68
+ The padding size for key and value in attention layer
69
+ stride_kv (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
70
+ The stride size for key and value in attention layer
71
+ padding_q (`List[int]`, *optional*, defaults to `[1, 1, 1]`):
72
+ The padding size for query in attention layer
73
+ stride_q (`List[int]`, *optional*, defaults to `[1, 1, 1]`):
74
+ The stride size for query in attention layer
75
+ initializer_range (`float`, *optional*, defaults to 0.02):
76
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
77
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
78
+ The epsilon used by the layer normalization layers.
79
+
80
+ Example:
81
+
82
+ ```python
83
+ >>> from transformers import CvtConfig, CvtModel
84
+
85
+ >>> # Initializing a Cvt msft/cvt style configuration
86
+ >>> configuration = CvtConfig()
87
+
88
+ >>> # Initializing a model (with random weights) from the msft/cvt style configuration
89
+ >>> model = CvtModel(configuration)
90
+
91
+ >>> # Accessing the model configuration
92
+ >>> configuration = model.config
93
+ ```"""
94
+
95
+ model_type = "cvt"
96
+
97
+ def __init__(
98
+ self,
99
+ num_channels=3,
100
+ patch_sizes=[7, 3, 3],
101
+ patch_stride=[4, 2, 2],
102
+ patch_padding=[2, 1, 1],
103
+ embed_dim=[64, 192, 384],
104
+ num_heads=[1, 3, 6],
105
+ depth=[1, 2, 10],
106
+ mlp_ratio=[4.0, 4.0, 4.0],
107
+ attention_drop_rate=[0.0, 0.0, 0.0],
108
+ drop_rate=[0.0, 0.0, 0.0],
109
+ drop_path_rate=[0.0, 0.0, 0.1],
110
+ qkv_bias=[True, True, True],
111
+ cls_token=[False, False, True],
112
+ qkv_projection_method=["dw_bn", "dw_bn", "dw_bn"],
113
+ kernel_qkv=[3, 3, 3],
114
+ padding_kv=[1, 1, 1],
115
+ stride_kv=[2, 2, 2],
116
+ padding_q=[1, 1, 1],
117
+ stride_q=[1, 1, 1],
118
+ initializer_range=0.02,
119
+ layer_norm_eps=1e-12,
120
+ **kwargs,
121
+ ):
122
+ super().__init__(**kwargs)
123
+ self.num_channels = num_channels
124
+ self.patch_sizes = patch_sizes
125
+ self.patch_stride = patch_stride
126
+ self.patch_padding = patch_padding
127
+ self.embed_dim = embed_dim
128
+ self.num_heads = num_heads
129
+ self.depth = depth
130
+ self.mlp_ratio = mlp_ratio
131
+ self.attention_drop_rate = attention_drop_rate
132
+ self.drop_rate = drop_rate
133
+ self.drop_path_rate = drop_path_rate
134
+ self.qkv_bias = qkv_bias
135
+ self.cls_token = cls_token
136
+ self.qkv_projection_method = qkv_projection_method
137
+ self.kernel_qkv = kernel_qkv
138
+ self.padding_kv = padding_kv
139
+ self.stride_kv = stride_kv
140
+ self.padding_q = padding_q
141
+ self.stride_q = stride_q
142
+ self.initializer_range = initializer_range
143
+ self.layer_norm_eps = layer_norm_eps
144
+
145
+
146
+ __all__ = ["CvtConfig"]
.venv/lib/python3.11/site-packages/transformers/models/cvt/modeling_cvt.py ADDED
@@ -0,0 +1,725 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch CvT model."""
16
+
17
+ import collections.abc
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
27
+ from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput
28
+ from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
29
+ from ...utils import logging
30
+ from .configuration_cvt import CvtConfig
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+ # General docstring
36
+ _CONFIG_FOR_DOC = "CvtConfig"
37
+
38
+ # Base docstring
39
+ _CHECKPOINT_FOR_DOC = "microsoft/cvt-13"
40
+ _EXPECTED_OUTPUT_SHAPE = [1, 384, 14, 14]
41
+
42
+ # Image classification docstring
43
+ _IMAGE_CLASS_CHECKPOINT = "microsoft/cvt-13"
44
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
45
+
46
+
47
+ @dataclass
48
+ class BaseModelOutputWithCLSToken(ModelOutput):
49
+ """
50
+ Base class for model's outputs, with potential hidden states and attentions.
51
+
52
+ Args:
53
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
54
+ Sequence of hidden-states at the output of the last layer of the model.
55
+ cls_token_value (`torch.FloatTensor` of shape `(batch_size, 1, hidden_size)`):
56
+ Classification token at the output of the last layer of the model.
57
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
58
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
59
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
60
+ plus the initial embedding outputs.
61
+ """
62
+
63
+ last_hidden_state: torch.FloatTensor = None
64
+ cls_token_value: torch.FloatTensor = None
65
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
66
+
67
+
68
+ # Copied from transformers.models.beit.modeling_beit.drop_path
69
+ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
70
+ """
71
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
72
+
73
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
74
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
75
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
76
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
77
+ argument.
78
+ """
79
+ if drop_prob == 0.0 or not training:
80
+ return input
81
+ keep_prob = 1 - drop_prob
82
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
83
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
84
+ random_tensor.floor_() # binarize
85
+ output = input.div(keep_prob) * random_tensor
86
+ return output
87
+
88
+
89
+ # Copied from transformers.models.beit.modeling_beit.BeitDropPath
90
+ class CvtDropPath(nn.Module):
91
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
92
+
93
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
94
+ super().__init__()
95
+ self.drop_prob = drop_prob
96
+
97
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
98
+ return drop_path(hidden_states, self.drop_prob, self.training)
99
+
100
+ def extra_repr(self) -> str:
101
+ return "p={}".format(self.drop_prob)
102
+
103
+
104
+ class CvtEmbeddings(nn.Module):
105
+ """
106
+ Construct the CvT embeddings.
107
+ """
108
+
109
+ def __init__(self, patch_size, num_channels, embed_dim, stride, padding, dropout_rate):
110
+ super().__init__()
111
+ self.convolution_embeddings = CvtConvEmbeddings(
112
+ patch_size=patch_size, num_channels=num_channels, embed_dim=embed_dim, stride=stride, padding=padding
113
+ )
114
+ self.dropout = nn.Dropout(dropout_rate)
115
+
116
+ def forward(self, pixel_values):
117
+ hidden_state = self.convolution_embeddings(pixel_values)
118
+ hidden_state = self.dropout(hidden_state)
119
+ return hidden_state
120
+
121
+
122
+ class CvtConvEmbeddings(nn.Module):
123
+ """
124
+ Image to Conv Embedding.
125
+ """
126
+
127
+ def __init__(self, patch_size, num_channels, embed_dim, stride, padding):
128
+ super().__init__()
129
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
130
+ self.patch_size = patch_size
131
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=stride, padding=padding)
132
+ self.normalization = nn.LayerNorm(embed_dim)
133
+
134
+ def forward(self, pixel_values):
135
+ pixel_values = self.projection(pixel_values)
136
+ batch_size, num_channels, height, width = pixel_values.shape
137
+ hidden_size = height * width
138
+ # rearrange "b c h w -> b (h w) c"
139
+ pixel_values = pixel_values.view(batch_size, num_channels, hidden_size).permute(0, 2, 1)
140
+ if self.normalization:
141
+ pixel_values = self.normalization(pixel_values)
142
+ # rearrange "b (h w) c" -> b c h w"
143
+ pixel_values = pixel_values.permute(0, 2, 1).view(batch_size, num_channels, height, width)
144
+ return pixel_values
145
+
146
+
147
+ class CvtSelfAttentionConvProjection(nn.Module):
148
+ def __init__(self, embed_dim, kernel_size, padding, stride):
149
+ super().__init__()
150
+ self.convolution = nn.Conv2d(
151
+ embed_dim,
152
+ embed_dim,
153
+ kernel_size=kernel_size,
154
+ padding=padding,
155
+ stride=stride,
156
+ bias=False,
157
+ groups=embed_dim,
158
+ )
159
+ self.normalization = nn.BatchNorm2d(embed_dim)
160
+
161
+ def forward(self, hidden_state):
162
+ hidden_state = self.convolution(hidden_state)
163
+ hidden_state = self.normalization(hidden_state)
164
+ return hidden_state
165
+
166
+
167
+ class CvtSelfAttentionLinearProjection(nn.Module):
168
+ def forward(self, hidden_state):
169
+ batch_size, num_channels, height, width = hidden_state.shape
170
+ hidden_size = height * width
171
+ # rearrange " b c h w -> b (h w) c"
172
+ hidden_state = hidden_state.view(batch_size, num_channels, hidden_size).permute(0, 2, 1)
173
+ return hidden_state
174
+
175
+
176
+ class CvtSelfAttentionProjection(nn.Module):
177
+ def __init__(self, embed_dim, kernel_size, padding, stride, projection_method="dw_bn"):
178
+ super().__init__()
179
+ if projection_method == "dw_bn":
180
+ self.convolution_projection = CvtSelfAttentionConvProjection(embed_dim, kernel_size, padding, stride)
181
+ self.linear_projection = CvtSelfAttentionLinearProjection()
182
+
183
+ def forward(self, hidden_state):
184
+ hidden_state = self.convolution_projection(hidden_state)
185
+ hidden_state = self.linear_projection(hidden_state)
186
+ return hidden_state
187
+
188
+
189
+ class CvtSelfAttention(nn.Module):
190
+ def __init__(
191
+ self,
192
+ num_heads,
193
+ embed_dim,
194
+ kernel_size,
195
+ padding_q,
196
+ padding_kv,
197
+ stride_q,
198
+ stride_kv,
199
+ qkv_projection_method,
200
+ qkv_bias,
201
+ attention_drop_rate,
202
+ with_cls_token=True,
203
+ **kwargs,
204
+ ):
205
+ super().__init__()
206
+ self.scale = embed_dim**-0.5
207
+ self.with_cls_token = with_cls_token
208
+ self.embed_dim = embed_dim
209
+ self.num_heads = num_heads
210
+
211
+ self.convolution_projection_query = CvtSelfAttentionProjection(
212
+ embed_dim,
213
+ kernel_size,
214
+ padding_q,
215
+ stride_q,
216
+ projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method,
217
+ )
218
+ self.convolution_projection_key = CvtSelfAttentionProjection(
219
+ embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method
220
+ )
221
+ self.convolution_projection_value = CvtSelfAttentionProjection(
222
+ embed_dim, kernel_size, padding_kv, stride_kv, projection_method=qkv_projection_method
223
+ )
224
+
225
+ self.projection_query = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
226
+ self.projection_key = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
227
+ self.projection_value = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
228
+
229
+ self.dropout = nn.Dropout(attention_drop_rate)
230
+
231
+ def rearrange_for_multi_head_attention(self, hidden_state):
232
+ batch_size, hidden_size, _ = hidden_state.shape
233
+ head_dim = self.embed_dim // self.num_heads
234
+ # rearrange 'b t (h d) -> b h t d'
235
+ return hidden_state.view(batch_size, hidden_size, self.num_heads, head_dim).permute(0, 2, 1, 3)
236
+
237
+ def forward(self, hidden_state, height, width):
238
+ if self.with_cls_token:
239
+ cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
240
+ batch_size, hidden_size, num_channels = hidden_state.shape
241
+ # rearrange "b (h w) c -> b c h w"
242
+ hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width)
243
+
244
+ key = self.convolution_projection_key(hidden_state)
245
+ query = self.convolution_projection_query(hidden_state)
246
+ value = self.convolution_projection_value(hidden_state)
247
+
248
+ if self.with_cls_token:
249
+ query = torch.cat((cls_token, query), dim=1)
250
+ key = torch.cat((cls_token, key), dim=1)
251
+ value = torch.cat((cls_token, value), dim=1)
252
+
253
+ head_dim = self.embed_dim // self.num_heads
254
+
255
+ query = self.rearrange_for_multi_head_attention(self.projection_query(query))
256
+ key = self.rearrange_for_multi_head_attention(self.projection_key(key))
257
+ value = self.rearrange_for_multi_head_attention(self.projection_value(value))
258
+
259
+ attention_score = torch.einsum("bhlk,bhtk->bhlt", [query, key]) * self.scale
260
+ attention_probs = torch.nn.functional.softmax(attention_score, dim=-1)
261
+ attention_probs = self.dropout(attention_probs)
262
+
263
+ context = torch.einsum("bhlt,bhtv->bhlv", [attention_probs, value])
264
+ # rearrange"b h t d -> b t (h d)"
265
+ _, _, hidden_size, _ = context.shape
266
+ context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, hidden_size, self.num_heads * head_dim)
267
+ return context
268
+
269
+
270
+ class CvtSelfOutput(nn.Module):
271
+ """
272
+ The residual connection is defined in CvtLayer instead of here (as is the case with other models), due to the
273
+ layernorm applied before each block.
274
+ """
275
+
276
+ def __init__(self, embed_dim, drop_rate):
277
+ super().__init__()
278
+ self.dense = nn.Linear(embed_dim, embed_dim)
279
+ self.dropout = nn.Dropout(drop_rate)
280
+
281
+ def forward(self, hidden_state, input_tensor):
282
+ hidden_state = self.dense(hidden_state)
283
+ hidden_state = self.dropout(hidden_state)
284
+ return hidden_state
285
+
286
+
287
+ class CvtAttention(nn.Module):
288
+ def __init__(
289
+ self,
290
+ num_heads,
291
+ embed_dim,
292
+ kernel_size,
293
+ padding_q,
294
+ padding_kv,
295
+ stride_q,
296
+ stride_kv,
297
+ qkv_projection_method,
298
+ qkv_bias,
299
+ attention_drop_rate,
300
+ drop_rate,
301
+ with_cls_token=True,
302
+ ):
303
+ super().__init__()
304
+ self.attention = CvtSelfAttention(
305
+ num_heads,
306
+ embed_dim,
307
+ kernel_size,
308
+ padding_q,
309
+ padding_kv,
310
+ stride_q,
311
+ stride_kv,
312
+ qkv_projection_method,
313
+ qkv_bias,
314
+ attention_drop_rate,
315
+ with_cls_token,
316
+ )
317
+ self.output = CvtSelfOutput(embed_dim, drop_rate)
318
+ self.pruned_heads = set()
319
+
320
+ def prune_heads(self, heads):
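+ # Note: this follows the generic head-pruning recipe, but CvtSelfAttention above
+ # exposes `projection_query`/`projection_key`/`projection_value` and `num_heads`
+ # rather than the `query`/`key`/`value`, `num_attention_heads` and
+ # `attention_head_size` attributes referenced below, so head pruning is effectively
+ # unsupported for this attention implementation.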
321
+ if len(heads) == 0:
322
+ return
323
+ heads, index = find_pruneable_heads_and_indices(
324
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
325
+ )
326
+
327
+ # Prune linear layers
328
+ self.attention.query = prune_linear_layer(self.attention.query, index)
329
+ self.attention.key = prune_linear_layer(self.attention.key, index)
330
+ self.attention.value = prune_linear_layer(self.attention.value, index)
331
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
332
+
333
+ # Update hyper params and store pruned heads
334
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
335
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
336
+ self.pruned_heads = self.pruned_heads.union(heads)
337
+
338
+ def forward(self, hidden_state, height, width):
339
+ self_output = self.attention(hidden_state, height, width)
340
+ attention_output = self.output(self_output, hidden_state)
341
+ return attention_output
342
+
343
+
344
+ class CvtIntermediate(nn.Module):
345
+ def __init__(self, embed_dim, mlp_ratio):
346
+ super().__init__()
347
+ self.dense = nn.Linear(embed_dim, int(embed_dim * mlp_ratio))
348
+ self.activation = nn.GELU()
349
+
350
+ def forward(self, hidden_state):
351
+ hidden_state = self.dense(hidden_state)
352
+ hidden_state = self.activation(hidden_state)
353
+ return hidden_state
354
+
355
+
356
+ class CvtOutput(nn.Module):
357
+ def __init__(self, embed_dim, mlp_ratio, drop_rate):
358
+ super().__init__()
359
+ self.dense = nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
360
+ self.dropout = nn.Dropout(drop_rate)
361
+
362
+ def forward(self, hidden_state, input_tensor):
363
+ hidden_state = self.dense(hidden_state)
364
+ hidden_state = self.dropout(hidden_state)
365
+ hidden_state = hidden_state + input_tensor
366
+ return hidden_state
367
+
368
+
369
+ class CvtLayer(nn.Module):
370
+ """
371
+ CvtLayer is composed of attention layers, normalization and multi-layer perceptrons (mlps).
372
+ """
373
+
374
+ def __init__(
375
+ self,
376
+ num_heads,
377
+ embed_dim,
378
+ kernel_size,
379
+ padding_q,
380
+ padding_kv,
381
+ stride_q,
382
+ stride_kv,
383
+ qkv_projection_method,
384
+ qkv_bias,
385
+ attention_drop_rate,
386
+ drop_rate,
387
+ mlp_ratio,
388
+ drop_path_rate,
389
+ with_cls_token=True,
390
+ ):
391
+ super().__init__()
392
+ self.attention = CvtAttention(
393
+ num_heads,
394
+ embed_dim,
395
+ kernel_size,
396
+ padding_q,
397
+ padding_kv,
398
+ stride_q,
399
+ stride_kv,
400
+ qkv_projection_method,
401
+ qkv_bias,
402
+ attention_drop_rate,
403
+ drop_rate,
404
+ with_cls_token,
405
+ )
406
+
407
+ self.intermediate = CvtIntermediate(embed_dim, mlp_ratio)
408
+ self.output = CvtOutput(embed_dim, mlp_ratio, drop_rate)
409
+ self.drop_path = CvtDropPath(drop_prob=drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
410
+ self.layernorm_before = nn.LayerNorm(embed_dim)
411
+ self.layernorm_after = nn.LayerNorm(embed_dim)
412
+
413
+ def forward(self, hidden_state, height, width):
414
+ self_attention_output = self.attention(
415
+ self.layernorm_before(hidden_state), # in Cvt, layernorm is applied before self-attention
416
+ height,
417
+ width,
418
+ )
419
+ attention_output = self_attention_output
420
+ attention_output = self.drop_path(attention_output)
421
+
422
+ # first residual connection
423
+ hidden_state = attention_output + hidden_state
424
+
425
+ # in Cvt, layernorm is also applied after self-attention
426
+ layer_output = self.layernorm_after(hidden_state)
427
+ layer_output = self.intermediate(layer_output)
428
+
429
+ # second residual connection is done here
430
+ layer_output = self.output(layer_output, hidden_state)
431
+ layer_output = self.drop_path(layer_output)
432
+ return layer_output
433
+
434
+
435
+ class CvtStage(nn.Module):
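+ # A stage pairs a convolutional token embedding with a stack of CvtLayer blocks;
+ # a classification token is only prepended when `config.cls_token[stage]` is True
+ # (typically the last stage only).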
436
+ def __init__(self, config, stage):
437
+ super().__init__()
438
+ self.config = config
439
+ self.stage = stage
440
+ if self.config.cls_token[self.stage]:
441
+ self.cls_token = nn.Parameter(torch.randn(1, 1, self.config.embed_dim[-1]))
442
+
443
+ self.embedding = CvtEmbeddings(
444
+ patch_size=config.patch_sizes[self.stage],
445
+ stride=config.patch_stride[self.stage],
446
+ num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1],
447
+ embed_dim=config.embed_dim[self.stage],
448
+ padding=config.patch_padding[self.stage],
449
+ dropout_rate=config.drop_rate[self.stage],
450
+ )
451
+
452
+ drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage])]
453
+
454
+ self.layers = nn.Sequential(
455
+ *[
456
+ CvtLayer(
457
+ num_heads=config.num_heads[self.stage],
458
+ embed_dim=config.embed_dim[self.stage],
459
+ kernel_size=config.kernel_qkv[self.stage],
460
+ padding_q=config.padding_q[self.stage],
461
+ padding_kv=config.padding_kv[self.stage],
462
+ stride_kv=config.stride_kv[self.stage],
463
+ stride_q=config.stride_q[self.stage],
464
+ qkv_projection_method=config.qkv_projection_method[self.stage],
465
+ qkv_bias=config.qkv_bias[self.stage],
466
+ attention_drop_rate=config.attention_drop_rate[self.stage],
467
+ drop_rate=config.drop_rate[self.stage],
468
+ drop_path_rate=drop_path_rates[self.stage],
469
+ mlp_ratio=config.mlp_ratio[self.stage],
470
+ with_cls_token=config.cls_token[self.stage],
471
+ )
472
+ for _ in range(config.depth[self.stage])
473
+ ]
474
+ )
475
+
476
+ def forward(self, hidden_state):
477
+ cls_token = None
478
+ hidden_state = self.embedding(hidden_state)
479
+ batch_size, num_channels, height, width = hidden_state.shape
480
+ # rearrange "b c h w -> b (h w) c"
481
+ hidden_state = hidden_state.view(batch_size, num_channels, height * width).permute(0, 2, 1)
482
+ if self.config.cls_token[self.stage]:
483
+ cls_token = self.cls_token.expand(batch_size, -1, -1)
484
+ hidden_state = torch.cat((cls_token, hidden_state), dim=1)
485
+
486
+ for layer in self.layers:
487
+ layer_outputs = layer(hidden_state, height, width)
488
+ hidden_state = layer_outputs
489
+
490
+ if self.config.cls_token[self.stage]:
491
+ cls_token, hidden_state = torch.split(hidden_state, [1, height * width], 1)
492
+ hidden_state = hidden_state.permute(0, 2, 1).view(batch_size, num_channels, height, width)
493
+ return hidden_state, cls_token
494
+
495
+
496
+ class CvtEncoder(nn.Module):
497
+ def __init__(self, config):
498
+ super().__init__()
499
+ self.config = config
500
+ self.stages = nn.ModuleList([])
501
+ for stage_idx in range(len(config.depth)):
502
+ self.stages.append(CvtStage(config, stage_idx))
503
+
504
+ def forward(self, pixel_values, output_hidden_states=False, return_dict=True):
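+ # Feature maps flow between stages in (batch_size, num_channels, height, width)
+ # format; `cls_token` is None unless the stage just executed was configured with a
+ # classification token, so after the loop it comes from the final stage.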
505
+ all_hidden_states = () if output_hidden_states else None
506
+ hidden_state = pixel_values
507
+
508
+ cls_token = None
509
+ for _, (stage_module) in enumerate(self.stages):
510
+ hidden_state, cls_token = stage_module(hidden_state)
511
+ if output_hidden_states:
512
+ all_hidden_states = all_hidden_states + (hidden_state,)
513
+
514
+ if not return_dict:
515
+ return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)
516
+
517
+ return BaseModelOutputWithCLSToken(
518
+ last_hidden_state=hidden_state,
519
+ cls_token_value=cls_token,
520
+ hidden_states=all_hidden_states,
521
+ )
522
+
523
+
524
+ class CvtPreTrainedModel(PreTrainedModel):
525
+ """
526
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
527
+ models.
528
+ """
529
+
530
+ config_class = CvtConfig
531
+ base_model_prefix = "cvt"
532
+ main_input_name = "pixel_values"
533
+ _no_split_modules = ["CvtLayer"]
534
+
535
+ def _init_weights(self, module):
536
+ """Initialize the weights"""
537
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
538
+ module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
539
+ if module.bias is not None:
540
+ module.bias.data.zero_()
541
+ elif isinstance(module, nn.LayerNorm):
542
+ module.bias.data.zero_()
543
+ module.weight.data.fill_(1.0)
544
+ elif isinstance(module, CvtStage):
545
+ if self.config.cls_token[module.stage]:
546
+ module.cls_token.data = nn.init.trunc_normal_(
547
+ torch.zeros(1, 1, self.config.embed_dim[-1]), mean=0.0, std=self.config.initializer_range
548
+ )
549
+
550
+
551
+ CVT_START_DOCSTRING = r"""
552
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
553
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
554
+ behavior.
555
+
556
+ Parameters:
557
+ config ([`CvtConfig`]): Model configuration class with all the parameters of the model.
558
+ Initializing with a config file does not load the weights associated with the model, only the
559
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
560
+ """
561
+
562
+ CVT_INPUTS_DOCSTRING = r"""
563
+ Args:
564
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
565
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CvtImageProcessor.__call__`]
566
+ for details.
567
+ output_hidden_states (`bool`, *optional*):
568
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
569
+ more detail.
570
+ return_dict (`bool`, *optional*):
571
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
572
+ """
573
+
574
+
575
+ @add_start_docstrings(
576
+ "The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.",
577
+ CVT_START_DOCSTRING,
578
+ )
579
+ class CvtModel(CvtPreTrainedModel):
580
+ def __init__(self, config, add_pooling_layer=True):
581
+ super().__init__(config)
582
+ self.config = config
583
+ self.encoder = CvtEncoder(config)
584
+ self.post_init()
585
+
586
+ def _prune_heads(self, heads_to_prune):
587
+ """
588
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
589
+ class PreTrainedModel
590
+ """
591
+ for layer, heads in heads_to_prune.items():
592
+ self.encoder.layer[layer].attention.prune_heads(heads)
593
+
594
+ @add_start_docstrings_to_model_forward(CVT_INPUTS_DOCSTRING)
595
+ @add_code_sample_docstrings(
596
+ checkpoint=_CHECKPOINT_FOR_DOC,
597
+ output_type=BaseModelOutputWithCLSToken,
598
+ config_class=_CONFIG_FOR_DOC,
599
+ modality="vision",
600
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
601
+ )
602
+ def forward(
603
+ self,
604
+ pixel_values: Optional[torch.Tensor] = None,
605
+ output_hidden_states: Optional[bool] = None,
606
+ return_dict: Optional[bool] = None,
607
+ ) -> Union[Tuple, BaseModelOutputWithCLSToken]:
608
+ output_hidden_states = (
609
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
610
+ )
611
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
612
+
613
+ if pixel_values is None:
614
+ raise ValueError("You have to specify pixel_values")
615
+
616
+ encoder_outputs = self.encoder(
617
+ pixel_values,
618
+ output_hidden_states=output_hidden_states,
619
+ return_dict=return_dict,
620
+ )
621
+ sequence_output = encoder_outputs[0]
622
+
623
+ if not return_dict:
624
+ return (sequence_output,) + encoder_outputs[1:]
625
+
626
+ return BaseModelOutputWithCLSToken(
627
+ last_hidden_state=sequence_output,
628
+ cls_token_value=encoder_outputs.cls_token_value,
629
+ hidden_states=encoder_outputs.hidden_states,
630
+ )
631
+
632
+
633
+ @add_start_docstrings(
634
+ """
635
+ Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
636
+ the [CLS] token) e.g. for ImageNet.
637
+ """,
638
+ CVT_START_DOCSTRING,
639
+ )
640
+ class CvtForImageClassification(CvtPreTrainedModel):
641
+ def __init__(self, config):
642
+ super().__init__(config)
643
+
644
+ self.num_labels = config.num_labels
645
+ self.cvt = CvtModel(config, add_pooling_layer=False)
646
+ self.layernorm = nn.LayerNorm(config.embed_dim[-1])
647
+ # Classifier head
648
+ self.classifier = (
649
+ nn.Linear(config.embed_dim[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
650
+ )
651
+
652
+ # Initialize weights and apply final processing
653
+ self.post_init()
654
+
655
+ @add_start_docstrings_to_model_forward(CVT_INPUTS_DOCSTRING)
656
+ @add_code_sample_docstrings(
657
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
658
+ output_type=ImageClassifierOutputWithNoAttention,
659
+ config_class=_CONFIG_FOR_DOC,
660
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
661
+ )
662
+ def forward(
663
+ self,
664
+ pixel_values: Optional[torch.Tensor] = None,
665
+ labels: Optional[torch.Tensor] = None,
666
+ output_hidden_states: Optional[bool] = None,
667
+ return_dict: Optional[bool] = None,
668
+ ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
669
+ r"""
670
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
671
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
672
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
673
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
674
+ """
675
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
676
+ outputs = self.cvt(
677
+ pixel_values,
678
+ output_hidden_states=output_hidden_states,
679
+ return_dict=return_dict,
680
+ )
681
+
682
+ sequence_output = outputs[0]
683
+ cls_token = outputs[1]
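+ # Two classification paths: when the last stage has a cls token, classify from its
+ # layer-normalized value; otherwise flatten the final feature map to
+ # (batch_size, height*width, num_channels), layer-normalize it and mean-pool over
+ # the spatial tokens before the classifier.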
684
+ if self.config.cls_token[-1]:
685
+ sequence_output = self.layernorm(cls_token)
686
+ else:
687
+ batch_size, num_channels, height, width = sequence_output.shape
688
+ # rearrange "b c h w -> b (h w) c"
689
+ sequence_output = sequence_output.view(batch_size, num_channels, height * width).permute(0, 2, 1)
690
+ sequence_output = self.layernorm(sequence_output)
691
+
692
+ sequence_output_mean = sequence_output.mean(dim=1)
693
+ logits = self.classifier(sequence_output_mean)
694
+
695
+ loss = None
696
+ if labels is not None:
697
+ if self.config.problem_type is None:
698
+ if self.config.num_labels == 1:
699
+ self.config.problem_type = "regression"
700
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
701
+ self.config.problem_type = "single_label_classification"
702
+ else:
703
+ self.config.problem_type = "multi_label_classification"
704
+
705
+ if self.config.problem_type == "regression":
706
+ loss_fct = MSELoss()
707
+ if self.config.num_labels == 1:
708
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
709
+ else:
710
+ loss = loss_fct(logits, labels)
711
+ elif self.config.problem_type == "single_label_classification":
712
+ loss_fct = CrossEntropyLoss()
713
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
714
+ elif self.config.problem_type == "multi_label_classification":
715
+ loss_fct = BCEWithLogitsLoss()
716
+ loss = loss_fct(logits, labels)
717
+
718
+ if not return_dict:
719
+ output = (logits,) + outputs[2:]
720
+ return ((loss,) + output) if loss is not None else output
721
+
722
+ return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
723
+
724
+
725
+ __all__ = ["CvtForImageClassification", "CvtModel", "CvtPreTrainedModel"]
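+ # Minimal usage sketch (assumes the `microsoft/cvt-13` checkpoint, mirroring the TF
+ # docstring examples in modeling_tf_cvt.py):
+ #
+ #     from transformers import AutoImageProcessor, CvtForImageClassification
+ #     from PIL import Image
+ #     import requests
+ #     import torch
+ #
+ #     url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ #     image = Image.open(requests.get(url, stream=True).raw)
+ #     image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13")
+ #     model = CvtForImageClassification.from_pretrained("microsoft/cvt-13")
+ #     inputs = image_processor(images=image, return_tensors="pt")
+ #     with torch.no_grad():
+ #         logits = model(**inputs).logits
+ #     print("Predicted class:", model.config.id2label[int(logits.argmax(-1))])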
.venv/lib/python3.11/site-packages/transformers/models/cvt/modeling_tf_cvt.py ADDED
@@ -0,0 +1,1096 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """TF 2.0 Cvt model."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import collections.abc
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import tensorflow as tf
24
+
25
+ from ...modeling_tf_outputs import TFImageClassifierOutputWithNoAttention
26
+ from ...modeling_tf_utils import (
27
+ TFModelInputType,
28
+ TFPreTrainedModel,
29
+ TFSequenceClassificationLoss,
30
+ get_initializer,
31
+ keras,
32
+ keras_serializable,
33
+ unpack_inputs,
34
+ )
35
+ from ...tf_utils import shape_list, stable_softmax
36
+ from ...utils import (
37
+ ModelOutput,
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ logging,
41
+ replace_return_docstrings,
42
+ )
43
+ from .configuration_cvt import CvtConfig
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ # General docstring
49
+ _CONFIG_FOR_DOC = "CvtConfig"
50
+
51
+
52
+ @dataclass
53
+ class TFBaseModelOutputWithCLSToken(ModelOutput):
54
+ """
55
+ Base class for model's outputs.
56
+
57
+ Args:
58
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
59
+ Sequence of hidden-states at the output of the last layer of the model.
60
+ cls_token_value (`tf.Tensor` of shape `(batch_size, 1, hidden_size)`):
61
+ Classification token at the output of the last layer of the model.
62
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
63
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
64
+ `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
65
+ the initial embedding outputs.
66
+ """
67
+
68
+ last_hidden_state: tf.Tensor = None
69
+ cls_token_value: tf.Tensor = None
70
+ hidden_states: Tuple[tf.Tensor, ...] | None = None
71
+
72
+
73
+ class TFCvtDropPath(keras.layers.Layer):
74
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
75
+ References:
76
+ (1) github.com:rwightman/pytorch-image-models
77
+ """
78
+
79
+ def __init__(self, drop_prob: float, **kwargs):
80
+ super().__init__(**kwargs)
81
+ self.drop_prob = drop_prob
82
+
83
+ def call(self, x: tf.Tensor, training=None):
84
+ if self.drop_prob == 0.0 or not training:
85
+ return x
86
+ keep_prob = 1 - self.drop_prob
87
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
88
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1, dtype=self.compute_dtype)
89
+ random_tensor = tf.floor(random_tensor)
90
+ return (x / keep_prob) * random_tensor
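+ # Each sample survives with probability keep_prob (floor(keep_prob + U[0, 1)) is 1
+ # with probability keep_prob, else 0); survivors are rescaled by 1 / keep_prob so the
+ # expected activation is unchanged. At inference time dropping is skipped entirely.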
91
+
92
+
93
+ class TFCvtEmbeddings(keras.layers.Layer):
94
+ """Construct the Convolutional Token Embeddings."""
95
+
96
+ def __init__(
97
+ self,
98
+ config: CvtConfig,
99
+ patch_size: int,
100
+ num_channels: int,
101
+ embed_dim: int,
102
+ stride: int,
103
+ padding: int,
104
+ dropout_rate: float,
105
+ **kwargs,
106
+ ):
107
+ super().__init__(**kwargs)
108
+ self.convolution_embeddings = TFCvtConvEmbeddings(
109
+ config,
110
+ patch_size=patch_size,
111
+ num_channels=num_channels,
112
+ embed_dim=embed_dim,
113
+ stride=stride,
114
+ padding=padding,
115
+ name="convolution_embeddings",
116
+ )
117
+ self.dropout = keras.layers.Dropout(dropout_rate)
118
+
119
+ def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
120
+ hidden_state = self.convolution_embeddings(pixel_values)
121
+ hidden_state = self.dropout(hidden_state, training=training)
122
+ return hidden_state
123
+
124
+ def build(self, input_shape=None):
125
+ if self.built:
126
+ return
127
+ self.built = True
128
+ if getattr(self, "convolution_embeddings", None) is not None:
129
+ with tf.name_scope(self.convolution_embeddings.name):
130
+ self.convolution_embeddings.build(None)
131
+
132
+
133
+ class TFCvtConvEmbeddings(keras.layers.Layer):
134
+ """Image to Convolution Embeddings. This convolutional operation aims to model local spatial contexts."""
135
+
136
+ def __init__(
137
+ self,
138
+ config: CvtConfig,
139
+ patch_size: int,
140
+ num_channels: int,
141
+ embed_dim: int,
142
+ stride: int,
143
+ padding: int,
144
+ **kwargs,
145
+ ):
146
+ super().__init__(**kwargs)
147
+ self.padding = keras.layers.ZeroPadding2D(padding=padding)
148
+ self.patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
149
+ self.projection = keras.layers.Conv2D(
150
+ filters=embed_dim,
151
+ kernel_size=patch_size,
152
+ strides=stride,
153
+ padding="valid",
154
+ data_format="channels_last",
155
+ kernel_initializer=get_initializer(config.initializer_range),
156
+ name="projection",
157
+ )
158
+ # Using the same default epsilon as PyTorch
159
+ self.normalization = keras.layers.LayerNormalization(epsilon=1e-5, name="normalization")
160
+ self.num_channels = num_channels
161
+ self.embed_dim = embed_dim
162
+
163
+ def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
164
+ if isinstance(pixel_values, dict):
165
+ pixel_values = pixel_values["pixel_values"]
166
+
167
+ pixel_values = self.projection(self.padding(pixel_values))
168
+
169
+ # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
170
+ batch_size, height, width, num_channels = shape_list(pixel_values)
171
+ hidden_size = height * width
172
+ pixel_values = tf.reshape(pixel_values, shape=(batch_size, hidden_size, num_channels))
173
+ pixel_values = self.normalization(pixel_values)
174
+
175
+ # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
176
+ pixel_values = tf.reshape(pixel_values, shape=(batch_size, height, width, num_channels))
177
+ return pixel_values
178
+
179
+ def build(self, input_shape=None):
180
+ if self.built:
181
+ return
182
+ self.built = True
183
+ if getattr(self, "projection", None) is not None:
184
+ with tf.name_scope(self.projection.name):
185
+ self.projection.build([None, None, None, self.num_channels])
186
+ if getattr(self, "normalization", None) is not None:
187
+ with tf.name_scope(self.normalization.name):
188
+ self.normalization.build([None, None, self.embed_dim])
189
+
190
+
191
+ class TFCvtSelfAttentionConvProjection(keras.layers.Layer):
192
+ """Convolutional projection layer."""
193
+
194
+ def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride: int, padding: int, **kwargs):
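+ # Depthwise convolution (groups == embed_dim, no bias) followed by batch
+ # normalization; this is the convolutional projection applied to the query, key and
+ # value token maps.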
195
+ super().__init__(**kwargs)
196
+ self.padding = keras.layers.ZeroPadding2D(padding=padding)
197
+ self.convolution = keras.layers.Conv2D(
198
+ filters=embed_dim,
199
+ kernel_size=kernel_size,
200
+ kernel_initializer=get_initializer(config.initializer_range),
201
+ padding="valid",
202
+ strides=stride,
203
+ use_bias=False,
204
+ name="convolution",
205
+ groups=embed_dim,
206
+ )
207
+ # Using the same default epsilon as PyTorch, TF uses (1 - pytorch momentum)
208
+ self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
209
+ self.embed_dim = embed_dim
210
+
211
+ def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
212
+ hidden_state = self.convolution(self.padding(hidden_state))
213
+ hidden_state = self.normalization(hidden_state, training=training)
214
+ return hidden_state
215
+
216
+ def build(self, input_shape=None):
217
+ if self.built:
218
+ return
219
+ self.built = True
220
+ if getattr(self, "convolution", None) is not None:
221
+ with tf.name_scope(self.convolution.name):
222
+ self.convolution.build([None, None, None, self.embed_dim])
223
+ if getattr(self, "normalization", None) is not None:
224
+ with tf.name_scope(self.normalization.name):
225
+ self.normalization.build([None, None, None, self.embed_dim])
226
+
227
+
228
+ class TFCvtSelfAttentionLinearProjection(keras.layers.Layer):
229
+ """Linear projection layer used to flatten tokens into 1D."""
230
+
231
+ def call(self, hidden_state: tf.Tensor) -> tf.Tensor:
232
+ # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
233
+ batch_size, height, width, num_channels = shape_list(hidden_state)
234
+ hidden_size = height * width
235
+ hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels))
236
+ return hidden_state
237
+
238
+
239
+ class TFCvtSelfAttentionProjection(keras.layers.Layer):
240
+ """Convolutional Projection for Attention."""
241
+
242
+ def __init__(
243
+ self,
244
+ config: CvtConfig,
245
+ embed_dim: int,
246
+ kernel_size: int,
247
+ stride: int,
248
+ padding: int,
249
+ projection_method: str = "dw_bn",
250
+ **kwargs,
251
+ ):
252
+ super().__init__(**kwargs)
253
+ if projection_method == "dw_bn":
254
+ self.convolution_projection = TFCvtSelfAttentionConvProjection(
255
+ config, embed_dim, kernel_size, stride, padding, name="convolution_projection"
256
+ )
257
+ self.linear_projection = TFCvtSelfAttentionLinearProjection()
258
+
259
+ def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
260
+ hidden_state = self.convolution_projection(hidden_state, training=training)
261
+ hidden_state = self.linear_projection(hidden_state)
262
+ return hidden_state
263
+
264
+ def build(self, input_shape=None):
265
+ if self.built:
266
+ return
267
+ self.built = True
268
+ if getattr(self, "convolution_projection", None) is not None:
269
+ with tf.name_scope(self.convolution_projection.name):
270
+ self.convolution_projection.build(None)
271
+
272
+
273
+ class TFCvtSelfAttention(keras.layers.Layer):
274
+ """
275
+ Self-attention layer. A depth-wise separable convolution operation (Convolutional Projection) is applied for
276
+ query, key, and value embeddings.
277
+ """
278
+
279
+ def __init__(
280
+ self,
281
+ config: CvtConfig,
282
+ num_heads: int,
283
+ embed_dim: int,
284
+ kernel_size: int,
285
+ stride_q: int,
286
+ stride_kv: int,
287
+ padding_q: int,
288
+ padding_kv: int,
289
+ qkv_projection_method: str,
290
+ qkv_bias: bool,
291
+ attention_drop_rate: float,
292
+ with_cls_token: bool = True,
293
+ **kwargs,
294
+ ):
295
+ super().__init__(**kwargs)
296
+ self.scale = embed_dim**-0.5
297
+ self.with_cls_token = with_cls_token
298
+ self.embed_dim = embed_dim
299
+ self.num_heads = num_heads
300
+
301
+ self.convolution_projection_query = TFCvtSelfAttentionProjection(
302
+ config,
303
+ embed_dim,
304
+ kernel_size,
305
+ stride_q,
306
+ padding_q,
307
+ projection_method="linear" if qkv_projection_method == "avg" else qkv_projection_method,
308
+ name="convolution_projection_query",
309
+ )
310
+ self.convolution_projection_key = TFCvtSelfAttentionProjection(
311
+ config,
312
+ embed_dim,
313
+ kernel_size,
314
+ stride_kv,
315
+ padding_kv,
316
+ projection_method=qkv_projection_method,
317
+ name="convolution_projection_key",
318
+ )
319
+ self.convolution_projection_value = TFCvtSelfAttentionProjection(
320
+ config,
321
+ embed_dim,
322
+ kernel_size,
323
+ stride_kv,
324
+ padding_kv,
325
+ projection_method=qkv_projection_method,
326
+ name="convolution_projection_value",
327
+ )
328
+
329
+ self.projection_query = keras.layers.Dense(
330
+ units=embed_dim,
331
+ kernel_initializer=get_initializer(config.initializer_range),
332
+ use_bias=qkv_bias,
333
+ bias_initializer="zeros",
334
+ name="projection_query",
335
+ )
336
+ self.projection_key = keras.layers.Dense(
337
+ units=embed_dim,
338
+ kernel_initializer=get_initializer(config.initializer_range),
339
+ use_bias=qkv_bias,
340
+ bias_initializer="zeros",
341
+ name="projection_key",
342
+ )
343
+ self.projection_value = keras.layers.Dense(
344
+ units=embed_dim,
345
+ kernel_initializer=get_initializer(config.initializer_range),
346
+ use_bias=qkv_bias,
347
+ bias_initializer="zeros",
348
+ name="projection_value",
349
+ )
350
+ self.dropout = keras.layers.Dropout(attention_drop_rate)
351
+
352
+ def rearrange_for_multi_head_attention(self, hidden_state: tf.Tensor) -> tf.Tensor:
353
+ batch_size, hidden_size, _ = shape_list(hidden_state)
354
+ head_dim = self.embed_dim // self.num_heads
355
+ hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, self.num_heads, head_dim))
356
+ hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1, 3))
357
+ return hidden_state
358
+
359
+ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
360
+ if self.with_cls_token:
361
+ cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1)
362
+
363
+ # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
364
+ batch_size, hidden_size, num_channels = shape_list(hidden_state)
365
+ hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels))
366
+
367
+ key = self.convolution_projection_key(hidden_state, training=training)
368
+ query = self.convolution_projection_query(hidden_state, training=training)
369
+ value = self.convolution_projection_value(hidden_state, training=training)
370
+
371
+ if self.with_cls_token:
372
+ query = tf.concat((cls_token, query), axis=1)
373
+ key = tf.concat((cls_token, key), axis=1)
374
+ value = tf.concat((cls_token, value), axis=1)
375
+
376
+ head_dim = self.embed_dim // self.num_heads
377
+
378
+ query = self.rearrange_for_multi_head_attention(self.projection_query(query))
379
+ key = self.rearrange_for_multi_head_attention(self.projection_key(key))
380
+ value = self.rearrange_for_multi_head_attention(self.projection_value(value))
381
+
382
+ attention_score = tf.matmul(query, key, transpose_b=True) * self.scale
383
+ attention_probs = stable_softmax(logits=attention_score, axis=-1)
384
+ attention_probs = self.dropout(attention_probs, training=training)
385
+
386
+ context = tf.matmul(attention_probs, value)
387
+ # "batch_size, num_heads, hidden_size, head_dim -> batch_size, hidden_size, (num_heads*head_dim)"
388
+ _, _, hidden_size, _ = shape_list(context)
389
+ context = tf.transpose(context, perm=(0, 2, 1, 3))
390
+ context = tf.reshape(context, (batch_size, hidden_size, self.num_heads * head_dim))
391
+ return context
392
+
393
+ def build(self, input_shape=None):
394
+ if self.built:
395
+ return
396
+ self.built = True
397
+ if getattr(self, "convolution_projection_query", None) is not None:
398
+ with tf.name_scope(self.convolution_projection_query.name):
399
+ self.convolution_projection_query.build(None)
400
+ if getattr(self, "convolution_projection_key", None) is not None:
401
+ with tf.name_scope(self.convolution_projection_key.name):
402
+ self.convolution_projection_key.build(None)
403
+ if getattr(self, "convolution_projection_value", None) is not None:
404
+ with tf.name_scope(self.convolution_projection_value.name):
405
+ self.convolution_projection_value.build(None)
406
+ if getattr(self, "projection_query", None) is not None:
407
+ with tf.name_scope(self.projection_query.name):
408
+ self.projection_query.build([None, None, self.embed_dim])
409
+ if getattr(self, "projection_key", None) is not None:
410
+ with tf.name_scope(self.projection_key.name):
411
+ self.projection_key.build([None, None, self.embed_dim])
412
+ if getattr(self, "projection_value", None) is not None:
413
+ with tf.name_scope(self.projection_value.name):
414
+ self.projection_value.build([None, None, self.embed_dim])
415
+
416
+
417
+ class TFCvtSelfOutput(keras.layers.Layer):
418
+ """Output of the Attention layer."""
419
+
420
+ def __init__(self, config: CvtConfig, embed_dim: int, drop_rate: float, **kwargs):
421
+ super().__init__(**kwargs)
422
+ self.dense = keras.layers.Dense(
423
+ units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense"
424
+ )
425
+ self.dropout = keras.layers.Dropout(drop_rate)
426
+ self.embed_dim = embed_dim
427
+
428
+ def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
429
+ hidden_state = self.dense(inputs=hidden_state)
430
+ hidden_state = self.dropout(inputs=hidden_state, training=training)
431
+ return hidden_state
432
+
433
+ def build(self, input_shape=None):
434
+ if self.built:
435
+ return
436
+ self.built = True
437
+ if getattr(self, "dense", None) is not None:
438
+ with tf.name_scope(self.dense.name):
439
+ self.dense.build([None, None, self.embed_dim])
440
+
441
+
442
+ class TFCvtAttention(keras.layers.Layer):
443
+ """Attention layer. First chunk of the convolutional transformer block."""
444
+
445
+ def __init__(
446
+ self,
447
+ config: CvtConfig,
448
+ num_heads: int,
449
+ embed_dim: int,
450
+ kernel_size: int,
451
+ stride_q: int,
452
+ stride_kv: int,
453
+ padding_q: int,
454
+ padding_kv: int,
455
+ qkv_projection_method: str,
456
+ qkv_bias: bool,
457
+ attention_drop_rate: float,
458
+ drop_rate: float,
459
+ with_cls_token: bool = True,
460
+ **kwargs,
461
+ ):
462
+ super().__init__(**kwargs)
463
+ self.attention = TFCvtSelfAttention(
464
+ config,
465
+ num_heads,
466
+ embed_dim,
467
+ kernel_size,
468
+ stride_q,
469
+ stride_kv,
470
+ padding_q,
471
+ padding_kv,
472
+ qkv_projection_method,
473
+ qkv_bias,
474
+ attention_drop_rate,
475
+ with_cls_token,
476
+ name="attention",
477
+ )
478
+ self.dense_output = TFCvtSelfOutput(config, embed_dim, drop_rate, name="output")
479
+
480
+ def prune_heads(self, heads):
481
+ raise NotImplementedError
482
+
483
+ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False):
484
+ self_output = self.attention(hidden_state, height, width, training=training)
485
+ attention_output = self.dense_output(self_output, training=training)
486
+ return attention_output
487
+
488
+ def build(self, input_shape=None):
489
+ if self.built:
490
+ return
491
+ self.built = True
492
+ if getattr(self, "attention", None) is not None:
493
+ with tf.name_scope(self.attention.name):
494
+ self.attention.build(None)
495
+ if getattr(self, "dense_output", None) is not None:
496
+ with tf.name_scope(self.dense_output.name):
497
+ self.dense_output.build(None)
498
+
499
+
500
+ class TFCvtIntermediate(keras.layers.Layer):
501
+ """Intermediate dense layer. Second chunk of the convolutional transformer block."""
502
+
503
+ def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, **kwargs):
504
+ super().__init__(**kwargs)
505
+ self.dense = keras.layers.Dense(
506
+ units=int(embed_dim * mlp_ratio),
507
+ kernel_initializer=get_initializer(config.initializer_range),
508
+ activation="gelu",
509
+ name="dense",
510
+ )
511
+ self.embed_dim = embed_dim
512
+
513
+ def call(self, hidden_state: tf.Tensor) -> tf.Tensor:
514
+ hidden_state = self.dense(hidden_state)
515
+ return hidden_state
516
+
517
+ def build(self, input_shape=None):
518
+ if self.built:
519
+ return
520
+ self.built = True
521
+ if getattr(self, "dense", None) is not None:
522
+ with tf.name_scope(self.dense.name):
523
+ self.dense.build([None, None, self.embed_dim])
524
+
525
+
526
+ class TFCvtOutput(keras.layers.Layer):
527
+ """
528
+ Output of the Convolutional Transformer Block (last chunk). It consists of an MLP and a residual connection.
529
+ """
530
+
531
+ def __init__(self, config: CvtConfig, embed_dim: int, mlp_ratio: int, drop_rate: int, **kwargs):
532
+ super().__init__(**kwargs)
533
+ self.dense = keras.layers.Dense(
534
+ units=embed_dim, kernel_initializer=get_initializer(config.initializer_range), name="dense"
535
+ )
536
+ self.dropout = keras.layers.Dropout(drop_rate)
537
+ self.embed_dim = embed_dim
538
+ self.mlp_ratio = mlp_ratio
539
+
540
+ def call(self, hidden_state: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
541
+ hidden_state = self.dense(inputs=hidden_state)
542
+ hidden_state = self.dropout(inputs=hidden_state, training=training)
543
+ hidden_state = hidden_state + input_tensor
544
+ return hidden_state
545
+
546
+ def build(self, input_shape=None):
547
+ if self.built:
548
+ return
549
+ self.built = True
550
+ if getattr(self, "dense", None) is not None:
551
+ with tf.name_scope(self.dense.name):
552
+ self.dense.build([None, None, int(self.embed_dim * self.mlp_ratio)])
553
+
554
+
555
+ class TFCvtLayer(keras.layers.Layer):
556
+ """
557
+ Convolutional Transformer Block composed of attention layers, normalization and multi-layer perceptrons (mlps). It
558
+ consists of 3 chunks: an attention layer, an intermediate dense layer and an output layer. This corresponds to the
559
+ `Block` class in the original implementation.
560
+ """
561
+
562
+ def __init__(
563
+ self,
564
+ config: CvtConfig,
565
+ num_heads: int,
566
+ embed_dim: int,
567
+ kernel_size: int,
568
+ stride_q: int,
569
+ stride_kv: int,
570
+ padding_q: int,
571
+ padding_kv: int,
572
+ qkv_projection_method: str,
573
+ qkv_bias: bool,
574
+ attention_drop_rate: float,
575
+ drop_rate: float,
576
+ mlp_ratio: float,
577
+ drop_path_rate: float,
578
+ with_cls_token: bool = True,
579
+ **kwargs,
580
+ ):
581
+ super().__init__(**kwargs)
582
+ self.attention = TFCvtAttention(
583
+ config,
584
+ num_heads,
585
+ embed_dim,
586
+ kernel_size,
587
+ stride_q,
588
+ stride_kv,
589
+ padding_q,
590
+ padding_kv,
591
+ qkv_projection_method,
592
+ qkv_bias,
593
+ attention_drop_rate,
594
+ drop_rate,
595
+ with_cls_token,
596
+ name="attention",
597
+ )
598
+ self.intermediate = TFCvtIntermediate(config, embed_dim, mlp_ratio, name="intermediate")
599
+ self.dense_output = TFCvtOutput(config, embed_dim, mlp_ratio, drop_rate, name="output")
600
+ # Using `layers.Activation` instead of `tf.identity` to better control `training` behaviour.
601
+ self.drop_path = (
602
+ TFCvtDropPath(drop_path_rate, name="drop_path")
603
+ if drop_path_rate > 0.0
604
+ else keras.layers.Activation("linear", name="drop_path")
605
+ )
606
+ # Using the same default epsilon as PyTorch
607
+ self.layernorm_before = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_before")
608
+ self.layernorm_after = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_after")
609
+ self.embed_dim = embed_dim
610
+
611
+ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool = False) -> tf.Tensor:
612
+ # in Cvt, layernorm is applied before self-attention
613
+ attention_output = self.attention(self.layernorm_before(hidden_state), height, width, training=training)
614
+ attention_output = self.drop_path(attention_output, training=training)
615
+
616
+ # first residual connection
617
+ hidden_state = attention_output + hidden_state
618
+
619
+ # in Cvt, layernorm is also applied after self-attention
620
+ layer_output = self.layernorm_after(hidden_state)
621
+ layer_output = self.intermediate(layer_output)
622
+
623
+ # second residual connection is done here
624
+ layer_output = self.dense_output(layer_output, hidden_state)
625
+ layer_output = self.drop_path(layer_output, training=training)
626
+ return layer_output
627
+
628
+ def build(self, input_shape=None):
629
+ if self.built:
630
+ return
631
+ self.built = True
632
+ if getattr(self, "attention", None) is not None:
633
+ with tf.name_scope(self.attention.name):
634
+ self.attention.build(None)
635
+ if getattr(self, "intermediate", None) is not None:
636
+ with tf.name_scope(self.intermediate.name):
637
+ self.intermediate.build(None)
638
+ if getattr(self, "dense_output", None) is not None:
639
+ with tf.name_scope(self.dense_output.name):
640
+ self.dense_output.build(None)
641
+ if getattr(self, "drop_path", None) is not None:
642
+ with tf.name_scope(self.drop_path.name):
643
+ self.drop_path.build(None)
644
+ if getattr(self, "layernorm_before", None) is not None:
645
+ with tf.name_scope(self.layernorm_before.name):
646
+ self.layernorm_before.build([None, None, self.embed_dim])
647
+ if getattr(self, "layernorm_after", None) is not None:
648
+ with tf.name_scope(self.layernorm_after.name):
649
+ self.layernorm_after.build([None, None, self.embed_dim])
650
+
651
+
652
+ class TFCvtStage(keras.layers.Layer):
653
+ """
654
+ Cvt stage (encoder block). Each stage has 2 parts:
655
+ - (1) A Convolutional Token Embedding layer
656
+ - (2) A Convolutional Transformer Block (layer).
657
+ The classification token is added only in the last stage.
658
+
659
+ Args:
660
+ config ([`CvtConfig`]): Model configuration class.
661
+ stage (`int`): Stage number.
662
+ """
663
+
664
+ def __init__(self, config: CvtConfig, stage: int, **kwargs):
665
+ super().__init__(**kwargs)
666
+ self.config = config
667
+ self.stage = stage
668
+ if self.config.cls_token[self.stage]:
669
+ self.cls_token = self.add_weight(
670
+ shape=(1, 1, self.config.embed_dim[-1]),
671
+ initializer=get_initializer(self.config.initializer_range),
672
+ trainable=True,
673
+ name="cvt.encoder.stages.2.cls_token",
674
+ )
675
+
676
+ self.embedding = TFCvtEmbeddings(
677
+ self.config,
678
+ patch_size=config.patch_sizes[self.stage],
679
+ num_channels=config.num_channels if self.stage == 0 else config.embed_dim[self.stage - 1],
680
+ stride=config.patch_stride[self.stage],
681
+ embed_dim=config.embed_dim[self.stage],
682
+ padding=config.patch_padding[self.stage],
683
+ dropout_rate=config.drop_rate[self.stage],
684
+ name="embedding",
685
+ )
686
+
687
+ drop_path_rates = tf.linspace(0.0, config.drop_path_rate[self.stage], config.depth[stage])
688
+ drop_path_rates = [x.numpy().item() for x in drop_path_rates]
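+ # Linearly spaced stochastic-depth schedule for this stage; as in the PyTorch
+ # CvtStage above, every block below reuses the entry at index `self.stage`.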
689
+ self.layers = [
690
+ TFCvtLayer(
691
+ config,
692
+ num_heads=config.num_heads[self.stage],
693
+ embed_dim=config.embed_dim[self.stage],
694
+ kernel_size=config.kernel_qkv[self.stage],
695
+ stride_q=config.stride_q[self.stage],
696
+ stride_kv=config.stride_kv[self.stage],
697
+ padding_q=config.padding_q[self.stage],
698
+ padding_kv=config.padding_kv[self.stage],
699
+ qkv_projection_method=config.qkv_projection_method[self.stage],
700
+ qkv_bias=config.qkv_bias[self.stage],
701
+ attention_drop_rate=config.attention_drop_rate[self.stage],
702
+ drop_rate=config.drop_rate[self.stage],
703
+ mlp_ratio=config.mlp_ratio[self.stage],
704
+ drop_path_rate=drop_path_rates[self.stage],
705
+ with_cls_token=config.cls_token[self.stage],
706
+ name=f"layers.{j}",
707
+ )
708
+ for j in range(config.depth[self.stage])
709
+ ]
710
+
711
+ def call(self, hidden_state: tf.Tensor, training: bool = False):
712
+ cls_token = None
713
+ hidden_state = self.embedding(hidden_state, training)
714
+
715
+ # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
716
+ batch_size, height, width, num_channels = shape_list(hidden_state)
717
+ hidden_size = height * width
718
+ hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels))
719
+
720
+ if self.config.cls_token[self.stage]:
721
+ cls_token = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
722
+ hidden_state = tf.concat((cls_token, hidden_state), axis=1)
723
+
724
+ for layer in self.layers:
725
+ layer_outputs = layer(hidden_state, height, width, training=training)
726
+ hidden_state = layer_outputs
727
+
728
+ if self.config.cls_token[self.stage]:
729
+ cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1)
730
+
731
+ # "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
732
+ hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels))
733
+ return hidden_state, cls_token
734
+
735
+ def build(self, input_shape=None):
736
+ if self.built:
737
+ return
738
+ self.built = True
739
+ if getattr(self, "embedding", None) is not None:
740
+ with tf.name_scope(self.embedding.name):
741
+ self.embedding.build(None)
742
+ if getattr(self, "layers", None) is not None:
743
+ for layer in self.layers:
744
+ with tf.name_scope(layer.name):
745
+ layer.build(None)
746
+
747
+
748
+ class TFCvtEncoder(keras.layers.Layer):
749
+ """
750
+ Convolutional Vision Transformer encoder. CVT has 3 stages of encoder blocks with their respective number of layers
751
+ (depth) being 1, 2 and 10.
752
+
753
+ Args:
754
+ config ([`CvtConfig`]): Model configuration class.
755
+ """
756
+
757
+ config_class = CvtConfig
758
+
759
+ def __init__(self, config: CvtConfig, **kwargs):
760
+ super().__init__(**kwargs)
761
+ self.config = config
762
+ self.stages = [
763
+ TFCvtStage(config, stage_idx, name=f"stages.{stage_idx}") for stage_idx in range(len(config.depth))
764
+ ]
765
+
766
+ def call(
767
+ self,
768
+ pixel_values: TFModelInputType,
769
+ output_hidden_states: Optional[bool] = False,
770
+ return_dict: Optional[bool] = True,
771
+ training: Optional[bool] = False,
772
+ ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
773
+ all_hidden_states = () if output_hidden_states else None
774
+ hidden_state = pixel_values
775
+ # When running on CPU, `keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width)
776
+ # as input format. So change the input format to (batch_size, height, width, num_channels).
777
+ hidden_state = tf.transpose(hidden_state, perm=(0, 2, 3, 1))
778
+
779
+ cls_token = None
780
+ for _, (stage_module) in enumerate(self.stages):
781
+ hidden_state, cls_token = stage_module(hidden_state, training=training)
782
+ if output_hidden_states:
783
+ all_hidden_states = all_hidden_states + (hidden_state,)
784
+
785
+ # Change back to (batch_size, num_channels, height, width) format to have uniformity in the modules
786
+ hidden_state = tf.transpose(hidden_state, perm=(0, 3, 1, 2))
787
+ if output_hidden_states:
788
+ all_hidden_states = tuple([tf.transpose(hs, perm=(0, 3, 1, 2)) for hs in all_hidden_states])
789
+
790
+ if not return_dict:
791
+ return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)
792
+
793
+ return TFBaseModelOutputWithCLSToken(
794
+ last_hidden_state=hidden_state,
795
+ cls_token_value=cls_token,
796
+ hidden_states=all_hidden_states,
797
+ )
798
+
799
+ def build(self, input_shape=None):
800
+ if self.built:
801
+ return
802
+ self.built = True
803
+ if getattr(self, "stages", None) is not None:
804
+ for layer in self.stages:
805
+ with tf.name_scope(layer.name):
806
+ layer.build(None)
807
+
808
+
809
+ @keras_serializable
810
+ class TFCvtMainLayer(keras.layers.Layer):
811
+ """Construct the Cvt model."""
812
+
813
+ config_class = CvtConfig
814
+
815
+ def __init__(self, config: CvtConfig, **kwargs):
816
+ super().__init__(**kwargs)
817
+ self.config = config
818
+ self.encoder = TFCvtEncoder(config, name="encoder")
819
+
820
+ @unpack_inputs
821
+ def call(
822
+ self,
823
+ pixel_values: TFModelInputType | None = None,
824
+ output_hidden_states: Optional[bool] = None,
825
+ return_dict: Optional[bool] = None,
826
+ training: Optional[bool] = False,
827
+ ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
828
+ if pixel_values is None:
829
+ raise ValueError("You have to specify pixel_values")
830
+
831
+ encoder_outputs = self.encoder(
832
+ pixel_values,
833
+ output_hidden_states=output_hidden_states,
834
+ return_dict=return_dict,
835
+ training=training,
836
+ )
837
+
838
+ sequence_output = encoder_outputs[0]
839
+
840
+ if not return_dict:
841
+ return (sequence_output,) + encoder_outputs[1:]
842
+
843
+ return TFBaseModelOutputWithCLSToken(
844
+ last_hidden_state=sequence_output,
845
+ cls_token_value=encoder_outputs.cls_token_value,
846
+ hidden_states=encoder_outputs.hidden_states,
847
+ )
848
+
849
+ def build(self, input_shape=None):
850
+ if self.built:
851
+ return
852
+ self.built = True
853
+ if getattr(self, "encoder", None) is not None:
854
+ with tf.name_scope(self.encoder.name):
855
+ self.encoder.build(None)
856
+
857
+
858
+ class TFCvtPreTrainedModel(TFPreTrainedModel):
859
+ """
860
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
861
+ models.
862
+ """
863
+
864
+ config_class = CvtConfig
865
+ base_model_prefix = "cvt"
866
+ main_input_name = "pixel_values"
867
+
868
+
869
+ TFCVT_START_DOCSTRING = r"""
870
+
871
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
872
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
873
+ etc.)
874
+
875
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
876
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
877
+ behavior.
878
+
879
+ <Tip>
880
+
881
+ TF 2.0 models accept two formats as inputs:
882
+
883
+ - having all inputs as keyword arguments (like PyTorch models), or
884
+ - having all inputs as a list, tuple or dict in the first positional arguments.
885
+
886
+ This second option is useful when using the [`keras.Model.fit`] method, which currently requires having all the
887
+ tensors in the first argument of the model call function: `model(inputs)`.
888
+
889
+ </Tip>
890
+
891
+ Args:
892
+ config ([`CvtConfig`]): Model configuration class with all the parameters of the model.
893
+ Initializing with a config file does not load the weights associated with the model, only the
894
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
895
+ """
896
+
897
+ TFCVT_INPUTS_DOCSTRING = r"""
898
+ Args:
899
+ pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
900
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`CvtImageProcessor.__call__`]
901
+ for details.
902
+
903
+ output_hidden_states (`bool`, *optional*):
904
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
905
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
906
+ used instead.
907
+ return_dict (`bool`, *optional*):
908
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
909
+ eager mode, in graph mode the value will always be set to True.
910
+ training (`bool`, *optional*, defaults to `False`):
911
+ Whether or not to use the model in training mode (some modules like dropout modules have different
912
+ behaviors between training and evaluation).
913
+ """
914
+
915
+
916
+ @add_start_docstrings(
917
+ "The bare Cvt Model transformer outputting raw hidden-states without any specific head on top.",
918
+ TFCVT_START_DOCSTRING,
919
+ )
920
+ class TFCvtModel(TFCvtPreTrainedModel):
921
+ def __init__(self, config: CvtConfig, *inputs, **kwargs):
922
+ super().__init__(config, *inputs, **kwargs)
923
+
924
+ self.cvt = TFCvtMainLayer(config, name="cvt")
925
+
926
+ @unpack_inputs
927
+ @add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING)
928
+ @replace_return_docstrings(output_type=TFBaseModelOutputWithCLSToken, config_class=_CONFIG_FOR_DOC)
929
+ def call(
930
+ self,
931
+ pixel_values: tf.Tensor | None = None,
932
+ output_hidden_states: Optional[bool] = None,
933
+ return_dict: Optional[bool] = None,
934
+ training: Optional[bool] = False,
935
+ ) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
936
+ r"""
937
+ Returns:
938
+
939
+ Examples:
940
+
941
+ ```python
942
+ >>> from transformers import AutoImageProcessor, TFCvtModel
943
+ >>> from PIL import Image
944
+ >>> import requests
945
+
946
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
947
+ >>> image = Image.open(requests.get(url, stream=True).raw)
948
+
949
+ >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13")
950
+ >>> model = TFCvtModel.from_pretrained("microsoft/cvt-13")
951
+
952
+ >>> inputs = image_processor(images=image, return_tensors="tf")
953
+ >>> outputs = model(**inputs)
954
+ >>> last_hidden_states = outputs.last_hidden_state
955
+ ```"""
956
+
957
+ if pixel_values is None:
958
+ raise ValueError("You have to specify pixel_values")
959
+
960
+ outputs = self.cvt(
961
+ pixel_values=pixel_values,
962
+ output_hidden_states=output_hidden_states,
963
+ return_dict=return_dict,
964
+ training=training,
965
+ )
966
+
967
+ if not return_dict:
968
+ return (outputs[0],) + outputs[1:]
969
+
970
+ return TFBaseModelOutputWithCLSToken(
971
+ last_hidden_state=outputs.last_hidden_state,
972
+ cls_token_value=outputs.cls_token_value,
973
+ hidden_states=outputs.hidden_states,
974
+ )
975
+
976
+ def build(self, input_shape=None):
977
+ if self.built:
978
+ return
979
+ self.built = True
980
+ if getattr(self, "cvt", None) is not None:
981
+ with tf.name_scope(self.cvt.name):
982
+ self.cvt.build(None)
983
+
984
+
985
+ @add_start_docstrings(
986
+ """
987
+ Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
988
+ the [CLS] token) e.g. for ImageNet.
989
+ """,
990
+ TFCVT_START_DOCSTRING,
991
+ )
992
+ class TFCvtForImageClassification(TFCvtPreTrainedModel, TFSequenceClassificationLoss):
993
+ def __init__(self, config: CvtConfig, *inputs, **kwargs):
994
+ super().__init__(config, *inputs, **kwargs)
995
+
996
+ self.num_labels = config.num_labels
997
+ self.cvt = TFCvtMainLayer(config, name="cvt")
998
+ # Using same default epsilon as in the original implementation.
999
+ self.layernorm = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm")
1000
+
1001
+ # Classifier head
1002
+ self.classifier = keras.layers.Dense(
1003
+ units=config.num_labels,
1004
+ kernel_initializer=get_initializer(config.initializer_range),
1005
+ use_bias=True,
1006
+ bias_initializer="zeros",
1007
+ name="classifier",
1008
+ )
1009
+ self.config = config
1010
+
1011
+ @unpack_inputs
1012
+ @add_start_docstrings_to_model_forward(TFCVT_INPUTS_DOCSTRING)
1013
+ @replace_return_docstrings(output_type=TFImageClassifierOutputWithNoAttention, config_class=_CONFIG_FOR_DOC)
1014
+ def call(
1015
+ self,
1016
+ pixel_values: tf.Tensor | None = None,
1017
+ labels: tf.Tensor | None = None,
1018
+ output_hidden_states: Optional[bool] = None,
1019
+ return_dict: Optional[bool] = None,
1020
+ training: Optional[bool] = False,
1021
+ ) -> Union[TFImageClassifierOutputWithNoAttention, Tuple[tf.Tensor]]:
1022
+ r"""
1023
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
1024
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1025
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
1026
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1027
+
1028
+ Returns:
1029
+
1030
+ Examples:
1031
+
1032
+ ```python
1033
+ >>> from transformers import AutoImageProcessor, TFCvtForImageClassification
1034
+ >>> import tensorflow as tf
1035
+ >>> from PIL import Image
1036
+ >>> import requests
1037
+
1038
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1039
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1040
+
1041
+ >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/cvt-13")
1042
+ >>> model = TFCvtForImageClassification.from_pretrained("microsoft/cvt-13")
1043
+
1044
+ >>> inputs = image_processor(images=image, return_tensors="tf")
1045
+ >>> outputs = model(**inputs)
1046
+ >>> logits = outputs.logits
1047
+ >>> # model predicts one of the 1000 ImageNet classes
1048
+ >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
1049
+ >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
1050
+ ```"""
1051
+
1052
+ outputs = self.cvt(
1053
+ pixel_values,
1054
+ output_hidden_states=output_hidden_states,
1055
+ return_dict=return_dict,
1056
+ training=training,
1057
+ )
1058
+
1059
+ sequence_output = outputs[0]
1060
+ cls_token = outputs[1]
1061
+ if self.config.cls_token[-1]:
1062
+ sequence_output = self.layernorm(cls_token)
1063
+ else:
1064
+ # rearrange "batch_size, num_channels, height, width -> batch_size, (height*width), num_channels"
1065
+ batch_size, num_channels, height, width = shape_list(sequence_output)
1066
+ sequence_output = tf.reshape(sequence_output, shape=(batch_size, num_channels, height * width))
1067
+ sequence_output = tf.transpose(sequence_output, perm=(0, 2, 1))
1068
+ sequence_output = self.layernorm(sequence_output)
1069
+
1070
+ sequence_output_mean = tf.reduce_mean(sequence_output, axis=1)
1071
+ logits = self.classifier(sequence_output_mean)
1072
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
1073
+
1074
+ if not return_dict:
1075
+ output = (logits,) + outputs[2:]
1076
+ return ((loss,) + output) if loss is not None else output
1077
+
1078
+ return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
1079
+
1080
+ def build(self, input_shape=None):
1081
+ if self.built:
1082
+ return
1083
+ self.built = True
1084
+ if getattr(self, "cvt", None) is not None:
1085
+ with tf.name_scope(self.cvt.name):
1086
+ self.cvt.build(None)
1087
+ if getattr(self, "layernorm", None) is not None:
1088
+ with tf.name_scope(self.layernorm.name):
1089
+ self.layernorm.build([None, None, self.config.embed_dim[-1]])
1090
+ if getattr(self, "classifier", None) is not None:
1091
+ if hasattr(self.classifier, "name"):
1092
+ with tf.name_scope(self.classifier.name):
1093
+ self.classifier.build([None, None, self.config.embed_dim[-1]])
1094
+
1095
+
1096
+ __all__ = ["TFCvtForImageClassification", "TFCvtModel", "TFCvtPreTrainedModel"]
.venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_encoder_decoder import *
22
+ from .modeling_encoder_decoder import *
23
+ from .modeling_flax_encoder_decoder import *
24
+ from .modeling_tf_encoder_decoder import *
25
+ else:
26
+ import sys
27
+
28
+ _file = globals()["__file__"]
29
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
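The `_LazyModule` indirection above defers importing the heavy framework-specific submodules until one of their attributes is first accessed. The sketch below is not the actual `_LazyModule` implementation, only a minimal stand-in using a PEP 562 module-level `__getattr__` to convey the idea:

```python
# Simplified stand-in for the lazy-import pattern (assumption: placed in a package __init__.py).
import importlib

_LAZY_ATTRS = {
    "EncoderDecoderConfig": ".configuration_encoder_decoder",
    "EncoderDecoderModel": ".modeling_encoder_decoder",
}


def __getattr__(name):  # called only when `name` is not found through the normal lookup
    if name in _LAZY_ATTRS:
        module = importlib.import_module(_LAZY_ATTRS[name], __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```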
.venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (893 Bytes).
 
.venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/__pycache__/configuration_encoder_decoder.cpython-311.pyc ADDED
Binary file (5.09 kB).
 
.venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/__pycache__/modeling_encoder_decoder.cpython-311.pyc ADDED
Binary file (36.2 kB).
 
.venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/__pycache__/modeling_flax_encoder_decoder.cpython-311.pyc ADDED
Binary file (41.9 kB).
 
.venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/__pycache__/modeling_tf_encoder_decoder.cpython-311.pyc ADDED
Binary file (34.4 kB).
 
.venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/configuration_encoder_decoder.py ADDED
@@ -0,0 +1,111 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ from ...configuration_utils import PretrainedConfig
19
+ from ...utils import logging
20
+ from ..auto import AutoConfig
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class EncoderDecoderConfig(PretrainedConfig):
27
+ r"""
28
+ [`EncoderDecoderConfig`] is the configuration class to store the configuration of a [`EncoderDecoderModel`]. It is
29
+ used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder
30
+ configs.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+ Args:
36
+ kwargs (*optional*):
37
+ Dictionary of keyword arguments. Notably:
38
+
39
+ - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
40
+ the encoder config.
41
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
42
+ the decoder config.
43
+
44
+ Examples:
45
+
46
+ ```python
47
+ >>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel
48
+
49
+ >>> # Initializing a BERT google-bert/bert-base-uncased style configuration
50
+ >>> config_encoder = BertConfig()
51
+ >>> config_decoder = BertConfig()
52
+
53
+ >>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
54
+
55
+ >>> # Initializing a Bert2Bert model (with random weights) from the google-bert/bert-base-uncased style configurations
56
+ >>> model = EncoderDecoderModel(config=config)
57
+
58
+ >>> # Accessing the model configuration
59
+ >>> config_encoder = model.config.encoder
60
+ >>> config_decoder = model.config.decoder
61
+ >>> # set decoder config to causal lm
62
+ >>> config_decoder.is_decoder = True
63
+ >>> config_decoder.add_cross_attention = True
64
+
65
+ >>> # Saving the model, including its configuration
66
+ >>> model.save_pretrained("my-model")
67
+
68
+ >>> # loading model and config from pretrained folder
69
+ >>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained("my-model")
70
+ >>> model = EncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
71
+ ```"""
72
+
73
+ model_type = "encoder-decoder"
74
+ sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig}
75
+ is_composition = True
76
+
77
+ def __init__(self, **kwargs):
78
+ super().__init__(**kwargs)
79
+ if "encoder" not in kwargs or "decoder" not in kwargs:
80
+ raise ValueError(
81
+ f"A configuraton of type {self.model_type} cannot be instantiated because "
82
+ f"both `encoder` and `decoder` sub-configurations were not passed, only {kwargs}"
83
+ )
84
+ encoder_config = kwargs.pop("encoder")
85
+ encoder_model_type = encoder_config.pop("model_type")
86
+ decoder_config = kwargs.pop("decoder")
87
+ decoder_model_type = decoder_config.pop("model_type")
88
+
89
+ self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
90
+ self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
91
+ self.is_encoder_decoder = True
92
+
93
+ @classmethod
94
+ def from_encoder_decoder_configs(
95
+ cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
96
+ ) -> PretrainedConfig:
97
+ r"""
98
+ Instantiate a [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model configuration and
99
+ decoder model configuration.
100
+
101
+ Returns:
102
+ [`EncoderDecoderConfig`]: An instance of a configuration object
103
+ """
104
+ logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
105
+ decoder_config.is_decoder = True
106
+ decoder_config.add_cross_attention = True
107
+
108
+ return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
109
+
110
+
111
+ __all__ = ["EncoderDecoderConfig"]
.venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py ADDED
@@ -0,0 +1,687 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Classes to support Encoder-Decoder architectures"""
16
+
17
+ import gc
18
+ import inspect
19
+ import os
20
+ import tempfile
21
+ import warnings
22
+ from typing import Optional, Tuple, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+ from torch.nn import CrossEntropyLoss
27
+
28
+ from ...configuration_utils import PretrainedConfig
29
+ from ...generation import GenerationMixin
30
+ from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
31
+ from ...modeling_utils import PreTrainedModel
32
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
33
+ from ..auto.configuration_auto import AutoConfig
34
+ from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM
35
+ from .configuration_encoder_decoder import EncoderDecoderConfig
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ _CONFIG_FOR_DOC = "EncoderDecoderConfig"
41
+
42
+ DEPRECATION_WARNING = (
43
+ "Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
44
+ " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
45
+ " fine-tuning a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the"
46
+ " labels, no need to pass them yourself anymore."
47
+ )
48
+
49
+ ENCODER_DECODER_START_DOCSTRING = r"""
50
+ This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
51
+ encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
52
+ [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`]
53
+ function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
54
+ generative task, like summarization.
55
+
56
+ The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
57
+ tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
58
+ Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, and Aliaksei
59
+ Severyn.
60
+
61
+ After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models
62
+ (see the examples for more information).
63
+
64
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
65
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
66
+ etc.)
67
+
68
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
69
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
70
+ and behavior.
71
+
72
+ Parameters:
73
+ config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
74
+ Initializing with a config file does not load the weights associated with the model, only the
75
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
76
+ """
77
+
78
+ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
79
+ Args:
80
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
81
+ Indices of input sequence tokens in the vocabulary.
82
+
83
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
84
+ [`PreTrainedTokenizer.__call__`] for details.
85
+
86
+ [What are input IDs?](../glossary#input-ids)
87
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
88
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
89
+
90
+ - 1 for tokens that are **not masked**,
91
+ - 0 for tokens that are **masked**.
92
+
93
+ [What are attention masks?](../glossary#attention-mask)
94
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
95
+ Indices of decoder input sequence tokens in the vocabulary.
96
+
97
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
98
+ [`PreTrainedTokenizer.__call__`] for details.
99
+
100
+ [What are input IDs?](../glossary#input-ids)
101
+
102
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
103
+ `past_key_values`).
104
+
105
+ For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the
106
+ right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`.
107
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
108
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
109
+ be used by default.
110
+ encoder_outputs (`tuple(torch.FloatTensor)`, *optional*):
111
+ This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
112
+ `last_hidden_state` (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`) is a tensor
113
+ of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the
114
+ decoder.
115
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
116
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
117
+
118
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
119
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
120
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
121
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
122
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
123
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
124
+ model's internal embedding lookup matrix.
125
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
126
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
127
+ representation. This is useful if you want more control over how to convert `decoder_input_ids` indices
128
+ into associated vectors than the model's internal embedding lookup matrix.
129
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
130
+ Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0,
131
+ ..., config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
132
+ (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
133
+ use_cache (`bool`, *optional*):
134
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
135
+ `past_key_values`).
136
+ output_attentions (`bool`, *optional*):
137
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
138
+ tensors for more detail.
139
+ output_hidden_states (`bool`, *optional*):
140
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
141
+ more detail.
142
+ return_dict (`bool`, *optional*):
143
+ If set to `True`, the model will return a [`~utils.Seq2SeqLMOutput`] instead of a plain tuple.
144
+ kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
145
+
146
+ - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
147
+ - With a *decoder_* prefix which will be input as `**decoder_kwargs` for the decoder forward function.
148
+ """
149
+
150
+
151
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
152
+ """
153
+ Shift input ids one token to the right.
154
+ """
155
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
156
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
157
+ if decoder_start_token_id is None:
158
+ raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
159
+ shifted_input_ids[:, 0] = decoder_start_token_id
160
+
161
+ if pad_token_id is None:
162
+ raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
163
+ # replace possible -100 values in labels by `pad_token_id`
164
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
165
+
166
+ return shifted_input_ids
167
+
168
+
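A worked example of `shift_tokens_right` as defined above: the decoder start token is prepended, every label moves one position to the right, and any `-100` that survives the shift is replaced by `pad_token_id`. The token ids below are arbitrary.

```python
# Hedged worked example for shift_tokens_right (arbitrary ids: pad=0, decoder_start=101).
import torch

from transformers.models.encoder_decoder.modeling_encoder_decoder import shift_tokens_right

labels = torch.tensor([[7, 8, 9, -100]])
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=101)
print(decoder_input_ids)  # tensor([[101, 7, 8, 9]]) -- the trailing -100 is dropped by the shift
```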
169
+ @add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
170
+ class EncoderDecoderModel(PreTrainedModel, GenerationMixin):
171
+ r"""
172
+ [`EncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
173
+ of the base model classes of the library as encoder and another one as decoder when created with the
174
+ [`~transformers.AutoModel.from_pretrained`] class method for the encoder and
175
+ [`~transformers.AutoModelForCausalLM.from_pretrained`] class method for the decoder.
176
+ """
177
+
178
+ config_class = EncoderDecoderConfig
179
+ base_model_prefix = "encoder_decoder"
180
+ main_input_name = "input_ids"
181
+ supports_gradient_checkpointing = True
182
+ _supports_param_buffer_assignment = False
183
+ _supports_flash_attn_2 = True
184
+ _supports_sdpa = True
185
+
186
+ def __init__(
187
+ self,
188
+ config: Optional[PretrainedConfig] = None,
189
+ encoder: Optional[PreTrainedModel] = None,
190
+ decoder: Optional[PreTrainedModel] = None,
191
+ ):
192
+ if config is None and (encoder is None or decoder is None):
193
+ raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
194
+ if config is None:
195
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
196
+ else:
197
+ if not isinstance(config, self.config_class):
198
+ raise ValueError(f"Config: {config} has to be of type {self.config_class}")
199
+
200
+ if config.decoder.cross_attention_hidden_size is not None:
201
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
202
+ raise ValueError(
203
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
204
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
205
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
206
+ " `config.encoder.hidden_size`."
207
+ )
208
+
209
+ # initialize with config
210
+ super().__init__(config)
211
+
212
+ if encoder is None:
213
+ from ..auto.modeling_auto import AutoModel
214
+
215
+ encoder = AutoModel.from_config(config.encoder)
216
+
217
+ if decoder is None:
218
+ from ..auto.modeling_auto import AutoModelForCausalLM
219
+
220
+ decoder = AutoModelForCausalLM.from_config(config.decoder)
221
+
222
+ self.encoder = encoder
223
+ self.decoder = decoder
224
+
225
+ if self.encoder.config.to_dict() != self.config.encoder.to_dict():
226
+ logger.warning(
227
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
228
+ f" {self.config.encoder}"
229
+ )
230
+ if self.decoder.config.to_dict() != self.config.decoder.to_dict():
231
+ logger.warning(
232
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
233
+ f" {self.config.decoder}"
234
+ )
235
+
236
+ # make sure that the individual model's config refers to the shared config
237
+ # so that the updates to the config will be synced
238
+ # update `_attn_implementation` because the attn is set in a deepcopied config within PreTrainedModel
239
+ self.config.encoder._attn_implementation = self.encoder.config._attn_implementation
240
+ self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
241
+ self.encoder.config = self.config.encoder
242
+ self.decoder.config = self.config.decoder
243
+
244
+ # encoder outputs might need to be projected to different dimension for decoder
245
+ if (
246
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
247
+ and self.decoder.config.cross_attention_hidden_size is None
248
+ ):
249
+ self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
250
+
251
+ if self.encoder.get_output_embeddings() is not None:
252
+ raise ValueError(
253
+ f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
254
+ )
255
+
256
+ decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
257
+ if "encoder_hidden_states" not in decoder_signature:
258
+ raise ValueError(
259
+ "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
260
+ "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
261
+ )
262
+
263
+ # tie encoder, decoder weights if config set accordingly
264
+ self.tie_weights()
265
+
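When the encoder and decoder hidden sizes differ and the decoder declares no `cross_attention_hidden_size`, the constructor above adds an `enc_to_dec_proj` linear layer that `forward` later applies to the encoder hidden states. A standalone sketch of that projection step, with made-up sizes:

```python
# Hedged sketch of the enc_to_dec_proj step (sizes are illustrative, not from any real checkpoint).
import torch
from torch import nn

encoder_hidden_size, decoder_hidden_size = 512, 768
enc_to_dec_proj = nn.Linear(encoder_hidden_size, decoder_hidden_size)

encoder_hidden_states = torch.randn(2, 16, encoder_hidden_size)  # (batch, seq_len, hidden)
projected = enc_to_dec_proj(encoder_hidden_states)
print(projected.shape)  # torch.Size([2, 16, 768]) -- matches what the decoder's cross-attention expects
```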
266
+ def tie_weights(self):
267
+ # tie encoder & decoder if needed
268
+ if self.config.tie_encoder_decoder:
269
+ # tie encoder and decoder base model
270
+ decoder_base_model_prefix = self.decoder.base_model_prefix
271
+ tied_weights = self._tie_encoder_decoder_weights(
272
+ self.encoder,
273
+ self.decoder._modules[decoder_base_model_prefix],
274
+ self.decoder.base_model_prefix,
275
+ "encoder",
276
+ )
277
+ # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
278
+ # attribute, not an instance member; modifying it would modify the entire class,
279
+ # leading to issues on subsequent calls (e.g. by different tests).
280
+ self._dynamic_tied_weights_keys = tied_weights
281
+
282
+ def get_encoder(self):
283
+ return self.encoder
284
+
285
+ def get_decoder(self):
286
+ return self.decoder
287
+
288
+ def get_input_embeddings(self):
289
+ return self.encoder.get_input_embeddings()
290
+
291
+ def get_output_embeddings(self):
292
+ return self.decoder.get_output_embeddings()
293
+
294
+ def set_output_embeddings(self, new_embeddings):
295
+ return self.decoder.set_output_embeddings(new_embeddings)
296
+
297
+ @classmethod
298
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
299
+ r"""
300
+ Example:
301
+
302
+ ```python
303
+ >>> from transformers import EncoderDecoderModel
304
+
305
+ >>> model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
306
+ ```"""
307
+
308
+ from_tf = kwargs.pop("from_tf", False)
309
+ if from_tf:
310
+ from transformers import TFEncoderDecoderModel
311
+
312
+ # a workaround to load from tensorflow checkpoint
313
+ # Using `_tf_model` won't work, because the weight names in the encoder/decoder of `_tf_model` get
314
+ # extended before saving those components. For example, The name of `_tf_model.encoder.vit` is
315
+ # `[top model name]/encoder/vit`, but the name of `tf_model.encoder.vit` is `[top model name]/vit`. The
316
+ # [top model name] is handled (stripped) by the conversion method, and the former case gets extra `encoder`,
317
+ # which should not occur when we want to save the components alone.
318
+ # There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see
319
+ # https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
320
+ # (the change in `src/transformers/modeling_tf_utils.py`)
321
+ _tf_model = TFEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
322
+ config = _tf_model.config
323
+
324
+ # Using `tf_model` instead
325
+ encoder = _tf_model.encoder.__class__(_tf_model.config.encoder)
326
+ decoder = _tf_model.decoder.__class__(_tf_model.config.decoder)
327
+ # Make sure models are built
328
+ encoder(encoder.dummy_inputs)
329
+ decoder(decoder.dummy_inputs)
330
+
331
+ # Get the variable correspondence between `_tf_model` and `encoder` and `decoder`
332
+ encoder_variables = {}
333
+ for v in encoder.trainable_variables + encoder.non_trainable_variables:
334
+ encoder_variables["/".join(v.name.split("/")[1:])] = v
335
+ decoder_variables = {}
336
+ for v in decoder.trainable_variables + decoder.non_trainable_variables:
337
+ decoder_variables["/".join(v.name.split("/")[1:])] = v
338
+
339
+ _encoder_variables = {}
340
+ for v in _tf_model.encoder.trainable_variables + _tf_model.encoder.non_trainable_variables:
341
+ _encoder_variables["/".join(v.name.split("/")[2:])] = v
342
+ _decoder_variables = {}
343
+ for v in _tf_model.decoder.trainable_variables + _tf_model.decoder.non_trainable_variables:
344
+ _decoder_variables["/".join(v.name.split("/")[2:])] = v
345
+
346
+ # assign weight values to `encoder` and `decoder` from `_tf_model`
347
+ for name, v in encoder_variables.items():
348
+ v.assign(_encoder_variables[name])
349
+ for name, v in decoder_variables.items():
350
+ v.assign(_decoder_variables[name])
351
+
352
+ tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
353
+
354
+ # Deal with `enc_to_dec_proj`
355
+ if hasattr(_tf_model, "enc_to_dec_proj"):
356
+ tf_model(tf_model.dummy_inputs)
357
+ tf_model.enc_to_dec_proj.kernel.assign(_tf_model.enc_to_dec_proj.kernel)
358
+ tf_model.enc_to_dec_proj.bias.assign(_tf_model.enc_to_dec_proj.bias)
359
+
360
+ with tempfile.TemporaryDirectory() as tmpdirname:
361
+ encoder_dir = os.path.join(tmpdirname, "encoder")
362
+ decoder_dir = os.path.join(tmpdirname, "decoder")
363
+ tf_model.encoder.save_pretrained(encoder_dir)
364
+ tf_model.decoder.save_pretrained(decoder_dir)
365
+
366
+ if hasattr(tf_model, "enc_to_dec_proj"):
367
+ enc_to_dec_proj_weight = torch.transpose(
368
+ torch.from_numpy(tf_model.enc_to_dec_proj.kernel.numpy()), 1, 0
369
+ )
370
+ enc_to_dec_proj_bias = torch.from_numpy(tf_model.enc_to_dec_proj.bias.numpy())
371
+
372
+ del _tf_model
373
+ del tf_model
374
+ gc.collect()
375
+
376
+ model = EncoderDecoderModel.from_encoder_decoder_pretrained(
377
+ encoder_dir, decoder_dir, encoder_from_tf=True, decoder_from_tf=True
378
+ )
379
+ # This is only for copying some specific attributes of this particular model.
380
+ model.config = config
381
+
382
+ if hasattr(model, "enc_to_dec_proj"):
383
+ model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous()
384
+ model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous()
385
+
386
+ return model
387
+
388
+ # At the moment fast initialization is not supported for composite models
389
+ if kwargs.get("_fast_init", False):
390
+ logger.warning(
391
+ "Fast initialization is currently not supported for EncoderDecoderModel. "
392
+ "Falling back to slow initialization..."
393
+ )
394
+ kwargs["_fast_init"] = False
395
+
396
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
397
+
398
+ @classmethod
399
+ def from_encoder_decoder_pretrained(
400
+ cls,
401
+ encoder_pretrained_model_name_or_path: str = None,
402
+ decoder_pretrained_model_name_or_path: str = None,
403
+ *model_args,
404
+ **kwargs,
405
+ ) -> PreTrainedModel:
406
+ r"""
407
+ Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
408
+ checkpoints.
409
+
410
+
411
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
412
+ the model, you need to first set it back in training mode with `model.train()`.
413
+
414
+ Params:
415
+ encoder_pretrained_model_name_or_path (`str`, *optional*):
416
+ Information necessary to initiate the encoder. Can be either:
417
+
418
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
419
+ - A path to a *directory* containing model weights saved using
420
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
421
+ - A path or url to a *tensorflow index checkpoint file* (e.g., `./tf_model/model.ckpt.index`). In
422
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
423
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
424
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
425
+
426
+ decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
427
+ Information necessary to initiate the decoder. Can be either:
428
+
429
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
430
+ - A path to a *directory* containing model weights saved using
431
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
432
+ - A path or url to a *tensorflow index checkpoint file* (e.g., `./tf_model/model.ckpt.index`). In
433
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
434
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
435
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
436
+
437
+ model_args (remaining positional arguments, *optional*):
438
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
439
+
440
+ kwargs (remaining dictionary of keyword arguments, *optional*):
441
+ Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
442
+ `output_attentions=True`).
443
+
444
+ - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
445
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
446
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
447
+
448
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
449
+
450
+ Example:
451
+
452
+ ```python
453
+ >>> from transformers import EncoderDecoderModel
454
+
455
+ >>> # initialize a bert2bert from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized
456
+ >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased")
457
+ >>> # saving model after fine-tuning
458
+ >>> model.save_pretrained("./bert2bert")
459
+ >>> # load fine-tuned model
460
+ >>> model = EncoderDecoderModel.from_pretrained("./bert2bert")
461
+ ```"""
462
+
463
+ kwargs_encoder = {
464
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
465
+ }
466
+
467
+ kwargs_decoder = {
468
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
469
+ }
470
+
471
+ # remove encoder, decoder kwargs from kwargs
472
+ for key in kwargs_encoder.keys():
473
+ del kwargs["encoder_" + key]
474
+ for key in kwargs_decoder.keys():
475
+ del kwargs["decoder_" + key]
476
+
477
+ # Load and initialize the encoder and decoder
478
+ # The distinction between encoder and decoder at the model level is made
479
+ # by the value of the flag `is_decoder` that we need to set correctly.
480
+ encoder = kwargs_encoder.pop("model", None)
481
+ if encoder is None:
482
+ if encoder_pretrained_model_name_or_path is None:
483
+ raise ValueError(
484
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
485
+ "to be defined."
486
+ )
487
+
488
+ if "config" not in kwargs_encoder:
489
+ encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
490
+ encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
491
+ )
492
+
493
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
494
+ logger.info(
495
+ f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
496
+ "from a decoder model. Cross-attention and casual mask are disabled."
497
+ )
498
+ encoder_config.is_decoder = False
499
+ encoder_config.add_cross_attention = False
500
+
501
+ kwargs_encoder["config"] = encoder_config
502
+
503
+ encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
504
+
505
+ decoder = kwargs_decoder.pop("model", None)
506
+ if decoder is None:
507
+ if decoder_pretrained_model_name_or_path is None:
508
+ raise ValueError(
509
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
510
+ "to be defined."
511
+ )
512
+
513
+ if "config" not in kwargs_decoder:
514
+ decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
515
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
516
+ )
517
+
518
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
519
+ logger.info(
520
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
521
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
522
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
523
+ )
524
+ decoder_config.is_decoder = True
525
+ decoder_config.add_cross_attention = True
526
+
527
+ kwargs_decoder["config"] = decoder_config
528
+
529
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
530
+ logger.warning(
531
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
532
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
533
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
534
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
535
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
536
+ )
537
+
538
+ decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
539
+
540
+ # instantiate config with corresponding kwargs
541
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
542
+ return cls(encoder=encoder, decoder=decoder, config=config)
543
+
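`from_encoder_decoder_pretrained` (and, in the same spirit, `forward`) routes keyword arguments by prefix: `encoder_*` and `decoder_*` kwargs are stripped of their prefix and handed to the respective sub-model, while everything else stays with the parent call. A small sketch of that splitting on a toy kwargs dict (the key names are illustrative only):

```python
# Hedged sketch of the prefix routing performed above, on a made-up kwargs dict.
kwargs = {"encoder_output_attentions": True, "decoder_num_hidden_layers": 6, "revision": "main"}

kwargs_encoder = {k[len("encoder_"):]: v for k, v in kwargs.items() if k.startswith("encoder_")}
kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}
kwargs_parent = {k: v for k, v in kwargs.items() if not k.startswith(("encoder_", "decoder_"))}

print(kwargs_encoder)  # {'output_attentions': True}
print(kwargs_decoder)  # {'num_hidden_layers': 6}
print(kwargs_parent)   # {'revision': 'main'}
```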
544
+ @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING)
545
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
546
+ def forward(
547
+ self,
548
+ input_ids: Optional[torch.LongTensor] = None,
549
+ attention_mask: Optional[torch.FloatTensor] = None,
550
+ decoder_input_ids: Optional[torch.LongTensor] = None,
551
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
552
+ encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
553
+ past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
554
+ inputs_embeds: Optional[torch.FloatTensor] = None,
555
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
556
+ labels: Optional[torch.LongTensor] = None,
557
+ use_cache: Optional[bool] = None,
558
+ output_attentions: Optional[bool] = None,
559
+ output_hidden_states: Optional[bool] = None,
560
+ return_dict: Optional[bool] = None,
561
+ **kwargs,
562
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
563
+ r"""
564
+ Returns:
565
+
566
+ Examples:
567
+
568
+ ```python
569
+ >>> from transformers import EncoderDecoderModel, BertTokenizer
570
+ >>> import torch
571
+
572
+ >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
573
+ >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
574
+ ... "google-bert/bert-base-uncased", "google-bert/bert-base-uncased"
575
+ ... ) # initialize Bert2Bert from pre-trained checkpoints
576
+
577
+ >>> # training
578
+ >>> model.config.decoder_start_token_id = tokenizer.cls_token_id
579
+ >>> model.config.pad_token_id = tokenizer.pad_token_id
580
+ >>> model.config.vocab_size = model.config.decoder.vocab_size
581
+
582
+ >>> input_ids = tokenizer("This is a really long text", return_tensors="pt").input_ids
583
+ >>> labels = tokenizer("This is the corresponding summary", return_tensors="pt").input_ids
584
+ >>> outputs = model(input_ids=input_ids, labels=labels)
585
+ >>> loss, logits = outputs.loss, outputs.logits
586
+
587
+ >>> # save and load from pretrained
588
+ >>> model.save_pretrained("bert2bert")
589
+ >>> model = EncoderDecoderModel.from_pretrained("bert2bert")
590
+
591
+ >>> # generation
592
+ >>> generated = model.generate(input_ids)
593
+ ```"""
594
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
595
+
596
+ kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
597
+
598
+ kwargs_decoder = {
599
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
600
+ }
601
+
602
+ if encoder_outputs is None:
603
+ encoder_outputs = self.encoder(
604
+ input_ids=input_ids,
605
+ attention_mask=attention_mask,
606
+ inputs_embeds=inputs_embeds,
607
+ output_attentions=output_attentions,
608
+ output_hidden_states=output_hidden_states,
609
+ return_dict=return_dict,
610
+ **kwargs_encoder,
611
+ )
612
+ elif isinstance(encoder_outputs, tuple):
613
+ encoder_outputs = BaseModelOutput(*encoder_outputs)
614
+
615
+ encoder_hidden_states = encoder_outputs[0]
616
+
617
+ # optionally project encoder_hidden_states
618
+ if (
619
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
620
+ and self.decoder.config.cross_attention_hidden_size is None
621
+ ):
622
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
623
+
624
+ if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
625
+ decoder_input_ids = shift_tokens_right(
626
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
627
+ )
628
+ if decoder_attention_mask is None:
629
+ decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
630
+
631
+ # Decode
632
+ decoder_outputs = self.decoder(
633
+ input_ids=decoder_input_ids,
634
+ attention_mask=decoder_attention_mask,
635
+ encoder_hidden_states=encoder_hidden_states,
636
+ encoder_attention_mask=attention_mask,
637
+ inputs_embeds=decoder_inputs_embeds,
638
+ output_attentions=output_attentions,
639
+ output_hidden_states=output_hidden_states,
640
+ use_cache=use_cache,
641
+ past_key_values=past_key_values,
642
+ return_dict=return_dict,
643
+ **kwargs_decoder,
644
+ )
645
+
646
+ # Compute loss independent from decoder (as some shift the logits inside them)
647
+ loss = None
648
+ if labels is not None:
649
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
650
+ logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
651
+ loss_fct = CrossEntropyLoss()
652
+ loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
653
+
654
+ if not return_dict:
655
+ if loss is not None:
656
+ return (loss,) + decoder_outputs + encoder_outputs
657
+ else:
658
+ return decoder_outputs + encoder_outputs
659
+
660
+ return Seq2SeqLMOutput(
661
+ loss=loss,
662
+ logits=decoder_outputs.logits,
663
+ past_key_values=decoder_outputs.past_key_values,
664
+ decoder_hidden_states=decoder_outputs.hidden_states,
665
+ decoder_attentions=decoder_outputs.attentions,
666
+ cross_attentions=decoder_outputs.cross_attentions,
667
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
668
+ encoder_hidden_states=encoder_outputs.hidden_states,
669
+ encoder_attentions=encoder_outputs.attentions,
670
+ )
671
+
672
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
673
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
674
+
675
+ def resize_token_embeddings(self, *args, **kwargs):
676
+ raise NotImplementedError(
677
+ "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
678
+ " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
679
+ " model.decoder.resize_token_embeddings(...))"
680
+ )
681
+
682
+ def _reorder_cache(self, past_key_values, beam_idx):
683
+ # apply decoder cache reordering here
684
+ return self.decoder._reorder_cache(past_key_values, beam_idx)
685
+
686
+
687
+ __all__ = ["EncoderDecoderModel"]
.venv/lib/python3.11/site-packages/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py ADDED
@@ -0,0 +1,901 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Classes to support Flax Encoder-Decoder architectures"""
16
+
17
+ import os
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import flax.linen as nn
21
+ import jax
22
+ import jax.numpy as jnp
23
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
24
+ from flax.traverse_util import flatten_dict, unflatten_dict
25
+ from jax import lax
26
+ from jax.random import PRNGKey
27
+
28
+ from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
29
+ from ...modeling_flax_utils import FlaxPreTrainedModel
30
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
31
+ from ..auto.configuration_auto import AutoConfig
32
+ from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
33
+ from .configuration_encoder_decoder import EncoderDecoderConfig
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ _CONFIG_FOR_DOC = "EncoderDecoderConfig"
39
+
40
+ ENCODER_DECODER_START_DOCSTRING = r"""
41
+ This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
42
+ encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
43
+ [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`]
44
+ function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
45
+ generative task, like summarization.
46
+
47
+ The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
48
+ tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
49
+ Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
50
+ Zhou, Wei Li, Peter J. Liu.
51
+
52
+ After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models
53
+ (see the examples for more information).
54
+
55
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
56
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
57
+ etc.)
58
+
59
+ This model is also a Flax Linen
60
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
61
+ regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
62
+
63
+ Parameters:
64
+ config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
65
+ Initializing with a config file does not load the weights associated with the model, only the
66
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
67
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
68
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
69
+ `jax.numpy.bfloat16` (on TPUs).
70
+
71
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
72
+ specified, all the computation will be performed with the given `dtype`.
73
+
74
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
75
+ parameters.**
76
+
77
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
78
+ [`~FlaxPreTrainedModel.to_bf16`].
79
+ """
80
+
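The `dtype` argument documented above only sets the computation dtype; casting the parameters themselves goes through `to_fp16`/`to_bf16`. A hedged sketch of the latter, assuming Flax weights are available for the two (illustrative) checkpoints and that a download is acceptable:

```python
# Minimal sketch, assuming Flax weights exist for both model ids.
import jax.numpy as jnp
from transformers import FlaxEncoderDecoderModel

model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google-bert/bert-base-uncased", "openai-community/gpt2"
)
# Cast the parameters to bfloat16; the computation dtype is controlled separately via `dtype`.
model.params = model.to_bf16(model.params)
```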
81
+ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
82
+ Args:
83
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
84
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
85
+ it.
86
+
87
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
88
+ [`PreTrainedTokenizer.__call__`] for details.
89
+
90
+ [What are input IDs?](../glossary#input-ids)
91
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
92
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
93
+
94
+ - 1 for tokens that are **not masked**,
95
+ - 0 for tokens that are **masked**.
96
+
97
+ [What are attention masks?](../glossary#attention-mask)
98
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
99
+ Indices of decoder input sequence tokens in the vocabulary.
100
+
101
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
102
+ [`PreTrainedTokenizer.__call__`] for details.
103
+
104
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
105
+
106
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
107
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
108
+ and prepending them with the `decoder_start_token_id`.
109
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
110
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
111
+ be used by default.
112
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
113
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
114
+ config.encoder.max_position_embeddings - 1]`.
115
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
116
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
117
+ range `[0, config.decoder.max_position_embeddings - 1]`.
118
+ output_attentions (`bool`, *optional*):
119
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
120
+ tensors for more detail.
121
+ output_hidden_states (`bool`, *optional*):
122
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
123
+ more detail.
124
+ return_dict (`bool`, *optional*):
125
+ If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
126
+ """
127
+
128
+ ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
129
+ Args:
130
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
131
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
132
+ it.
133
+
134
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
135
+ [`PreTrainedTokenizer.__call__`] for details.
136
+
137
+ [What are input IDs?](../glossary#input-ids)
138
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
139
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
140
+
141
+ - 1 for tokens that are **not masked**,
142
+ - 0 for tokens that are **masked**.
143
+
144
+ [What are attention masks?](../glossary#attention-mask)
145
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
146
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
147
+ config.encoder.max_position_embeddings - 1]`.
148
+ output_attentions (`bool`, *optional*):
149
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
150
+ tensors for more detail.
151
+ output_hidden_states (`bool`, *optional*):
152
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
153
+ more detail.
154
+ return_dict (`bool`, *optional*):
155
+ If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
156
+ """
157
+
158
+ ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
159
+ Args:
160
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
161
+ Indices of decoder input sequence tokens in the vocabulary.
162
+
163
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
164
+ [`PreTrainedTokenizer.__call__`] for details.
165
+
166
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
167
+
168
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
169
+ `past_key_values`).
170
+
171
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
172
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
173
+ and prepending them with the `decoder_start_token_id`.
174
+         encoder_outputs (`tuple(tuple(jnp.ndarray))`):
+             Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
+             `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
+             hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
178
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
179
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
180
+
181
+ - 1 for tokens that are **not masked**,
182
+ - 0 for tokens that are **masked**.
183
+
184
+ [What are attention masks?](../glossary#attention-mask)
185
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
186
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
187
+ be used by default.
188
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
189
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
190
+ range `[0, config.decoder.max_position_embeddings - 1]`.
191
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
192
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
193
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
194
+ output_attentions (`bool`, *optional*):
195
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
196
+ tensors for more detail.
197
+ output_hidden_states (`bool`, *optional*):
198
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
199
+ more detail.
200
+ return_dict (`bool`, *optional*):
201
+ If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
202
+ plain tuple.
203
+ """
204
+
205
+
206
+ class FlaxEncoderDecoderModule(nn.Module):
207
+ config: EncoderDecoderConfig
208
+ dtype: jnp.dtype = jnp.float32
209
+
210
+ def setup(self):
211
+ encoder_config = self.config.encoder
212
+ decoder_config = self.config.decoder
213
+
214
+ # Copied from `modeling_hybrid_clip.py` with modifications.
215
+ from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING
216
+
217
+ encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class
218
+ decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class
219
+
220
+ self.encoder = encoder_module(encoder_config, dtype=self.dtype)
221
+ self.decoder = decoder_module(decoder_config, dtype=self.dtype)
222
+
223
+ # encoder outputs might need to be projected to different dimension for decoder
224
+ if (
225
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
226
+ and self.decoder.config.cross_attention_hidden_size is None
227
+ ):
228
+ self.enc_to_dec_proj = nn.Dense(
229
+ self.decoder.config.hidden_size,
230
+ kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
231
+ dtype=self.dtype,
232
+ )
233
+ else:
234
+ self.enc_to_dec_proj = None
235
+
236
+ def _get_encoder_module(self):
237
+ return self.encoder
238
+
239
+ def _get_projection_module(self):
240
+ return self.enc_to_dec_proj
241
+
242
+ def _get_decoder_module(self):
243
+ return self.decoder
244
+
245
+ def __call__(
246
+ self,
247
+ input_ids,
248
+ attention_mask,
249
+ decoder_input_ids,
250
+ decoder_attention_mask,
251
+ position_ids,
252
+ decoder_position_ids,
253
+ output_attentions: bool = False,
254
+ output_hidden_states: bool = False,
255
+ return_dict: bool = True,
256
+ deterministic: bool = True,
257
+ ):
258
+ encoder_outputs = self.encoder(
259
+ input_ids=input_ids,
260
+ attention_mask=attention_mask,
261
+ position_ids=position_ids,
262
+ output_attentions=output_attentions,
263
+ output_hidden_states=output_hidden_states,
264
+ return_dict=return_dict,
265
+ deterministic=deterministic,
266
+ )
267
+
268
+ encoder_hidden_states = encoder_outputs[0]
269
+
270
+ # optionally project encoder_hidden_states
271
+ if self.enc_to_dec_proj is not None:
272
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
273
+
274
+ decoder_outputs = self.decoder(
275
+ input_ids=decoder_input_ids,
276
+ attention_mask=decoder_attention_mask,
277
+ position_ids=decoder_position_ids,
278
+ encoder_hidden_states=encoder_hidden_states,
279
+ encoder_attention_mask=attention_mask,
280
+ output_attentions=output_attentions,
281
+ output_hidden_states=output_hidden_states,
282
+ return_dict=return_dict,
283
+ deterministic=deterministic,
284
+ )
285
+
286
+ if not return_dict:
287
+ return decoder_outputs + encoder_outputs
288
+
289
+ return FlaxSeq2SeqLMOutput(
290
+ logits=decoder_outputs.logits,
291
+ decoder_hidden_states=decoder_outputs.hidden_states,
292
+ decoder_attentions=decoder_outputs.attentions,
293
+ cross_attentions=decoder_outputs.cross_attentions,
294
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
295
+ encoder_hidden_states=encoder_outputs.hidden_states,
296
+ encoder_attentions=encoder_outputs.attentions,
297
+ )
298
+
299
+
300
+ @add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
301
+ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
302
+ r"""
303
+ [`FlaxEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with
304
+ the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one as
305
+     decoder module when created with the [`~transformers.FlaxAutoModel.from_pretrained`] class method for the
+     encoder and the [`~transformers.FlaxAutoModelForCausalLM.from_pretrained`] class method for the decoder.
307
+ """
308
+
309
+ config_class = EncoderDecoderConfig
310
+ base_model_prefix = "encoder_decoder"
311
+ module_class = FlaxEncoderDecoderModule
312
+
313
+ def __init__(
314
+ self,
315
+ config: EncoderDecoderConfig,
316
+ input_shape: Optional[Tuple] = None,
317
+ seed: int = 0,
318
+ dtype: jnp.dtype = jnp.float32,
319
+ _do_init: bool = True,
320
+ **kwargs,
321
+ ):
322
+ if input_shape is None:
323
+ input_shape = ((1, 1), (1, 1))
324
+
325
+ if not _do_init:
326
+ raise ValueError(
327
+ "`FlaxEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
328
+ )
329
+
330
+ if config.decoder.cross_attention_hidden_size is not None:
331
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
332
+ raise ValueError(
333
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
334
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
335
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
336
+ " `config.encoder.hidden_size`."
337
+ )
338
+
339
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
340
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
341
+
342
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
343
+ encoder_input_shape, decoder_input_shape = input_shape
344
+
345
+ # init input tensors
346
+ input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
347
+ attention_mask = jnp.ones_like(input_ids)
348
+ decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
349
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
350
+
351
+ batch_size, sequence_length = input_ids.shape
352
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
353
+
354
+ decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
355
+ if not decoder_batch_size == batch_size:
356
+ raise ValueError(
357
+ f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder"
358
+ f" and {decoder_batch_size} for decoder."
359
+ )
360
+ decoder_position_ids = jnp.broadcast_to(
361
+ jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
362
+ )
363
+
364
+ params_rng, dropout_rng = jax.random.split(rng)
365
+ rngs = {"params": params_rng, "dropout": dropout_rng}
366
+
367
+ random_params = self.module.init(
368
+ rngs,
369
+ input_ids,
370
+ attention_mask,
371
+ decoder_input_ids,
372
+ decoder_attention_mask,
373
+ position_ids,
374
+ decoder_position_ids,
375
+ )["params"]
376
+
377
+ if params is not None:
378
+ random_params = flatten_dict(unfreeze(random_params))
379
+ params = flatten_dict(unfreeze(params))
380
+ for missing_key in self._missing_keys:
381
+ params[missing_key] = random_params[missing_key]
382
+ self._missing_keys = set()
383
+ return freeze(unflatten_dict(params))
384
+ else:
385
+ return random_params
386
+
387
+ def init_cache(self, batch_size, max_length, encoder_outputs):
388
+ r"""
389
+ Args:
390
+ batch_size (`int`):
391
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
392
+ max_length (`int`):
393
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
394
+ cache.
395
+             encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
+                 `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
+                 `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a
+                 sequence of hidden-states at the output of the last layer of the encoder. Used in the
+                 cross-attention of the decoder.
400
+ """
401
+ # init input variables to retrieve cache
402
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
403
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
404
+ decoder_position_ids = jnp.broadcast_to(
405
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
406
+ )
407
+
408
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
409
+ decoder_module = module._get_decoder_module()
410
+ return decoder_module(
411
+ input_ids=decoder_input_ids,
412
+ attention_mask=decoder_attention_mask,
413
+ position_ids=decoder_position_ids,
414
+ **kwargs,
415
+ )
416
+
417
+ init_variables = self.module.init(
418
+ jax.random.PRNGKey(0),
419
+ decoder_input_ids=decoder_input_ids,
420
+ decoder_attention_mask=decoder_attention_mask,
421
+ decoder_position_ids=decoder_position_ids,
422
+ encoder_hidden_states=encoder_outputs[0],
423
+ init_cache=True,
424
+ method=_decoder_forward, # we only need to call the decoder to init the cache
425
+ )
426
+ return unfreeze(init_variables["cache"])
427
+
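+     # Illustrative usage sketch (an assumption, not part of the library API): the cache
+     # returned by `init_cache` is what `prepare_inputs_for_generation` hands back as
+     # `past_key_values` for fast auto-regressive decoding.
+     def _init_cache_usage_example(self, input_ids, max_length=32):
+         encoder_outputs = self.encode(input_ids)
+         return self.init_cache(input_ids.shape[0], max_length, encoder_outputs)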
428
+ @add_start_docstrings(ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
429
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
430
+ def encode(
431
+ self,
432
+ input_ids: jnp.ndarray,
433
+ attention_mask: Optional[jnp.ndarray] = None,
434
+ position_ids: Optional[jnp.ndarray] = None,
435
+ output_attentions: Optional[bool] = None,
436
+ output_hidden_states: Optional[bool] = None,
437
+ return_dict: Optional[bool] = None,
438
+ train: bool = False,
439
+ params: dict = None,
440
+ dropout_rng: PRNGKey = None,
441
+ ):
442
+ r"""
443
+ Returns:
444
+
445
+ Example:
446
+
447
+ ```python
448
+ >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer
449
+
450
+ >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
451
+ >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
452
+
453
+ >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
454
+
455
+ >>> text = "My friends are cool but they eat too many carbs."
456
+ >>> input_ids = tokenizer.encode(text, return_tensors="np")
457
+ >>> encoder_outputs = model.encode(input_ids)
458
+ ```"""
459
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
460
+ output_hidden_states = (
461
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
462
+ )
463
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
464
+
465
+ if attention_mask is None:
466
+ attention_mask = jnp.ones_like(input_ids)
467
+ if position_ids is None:
468
+ batch_size, sequence_length = input_ids.shape
469
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
470
+
471
+ # Handle any PRNG if needed
472
+ rngs = {}
473
+ if dropout_rng is not None:
474
+ rngs["dropout"] = dropout_rng
475
+
476
+ def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
477
+ encode_module = module._get_encoder_module()
478
+ return encode_module(input_ids, attention_mask, position_ids, **kwargs)
479
+
480
+ outputs = self.module.apply(
481
+ {"params": params or self.params},
482
+ input_ids=jnp.array(input_ids, dtype="i4"),
483
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
484
+ position_ids=jnp.array(position_ids, dtype="i4"),
485
+ output_attentions=output_attentions,
486
+ output_hidden_states=output_hidden_states,
487
+ return_dict=return_dict,
488
+ deterministic=not train,
489
+ rngs=rngs,
490
+ method=_encoder_forward,
491
+ )
492
+
493
+ if return_dict:
494
+ outputs = FlaxBaseModelOutput(
495
+ last_hidden_state=outputs.last_hidden_state,
496
+ hidden_states=outputs.hidden_states,
497
+ attentions=outputs.attentions,
498
+ )
499
+
500
+ return outputs
501
+
502
+ @add_start_docstrings(ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
503
+ @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
504
+ def decode(
505
+ self,
506
+ decoder_input_ids,
507
+ encoder_outputs,
508
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
509
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
510
+ decoder_position_ids: Optional[jnp.ndarray] = None,
511
+ past_key_values: dict = None,
512
+ output_attentions: Optional[bool] = None,
513
+ output_hidden_states: Optional[bool] = None,
514
+ return_dict: Optional[bool] = None,
515
+ train: bool = False,
516
+ params: dict = None,
517
+ dropout_rng: PRNGKey = None,
518
+ ):
519
+ r"""
520
+ Returns:
521
+
522
+ Example:
523
+
524
+ ```python
525
+ >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer
526
+ >>> import jax.numpy as jnp
527
+
528
+ >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
529
+ >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
530
+
531
+ >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
532
+
533
+ >>> text = "My friends are cool but they eat too many carbs."
534
+ >>> input_ids = tokenizer.encode(text, max_length=1024, return_tensors="np")
535
+ >>> encoder_outputs = model.encode(input_ids)
536
+
537
+ >>> decoder_start_token_id = model.config.decoder.bos_token_id
538
+ >>> decoder_input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
539
+
540
+ >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
541
+ >>> logits = outputs.logits
542
+ ```"""
543
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
544
+ output_hidden_states = (
545
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
546
+ )
547
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
548
+
549
+ encoder_hidden_states = encoder_outputs[0]
550
+ if encoder_attention_mask is None:
551
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
552
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
553
+
554
+ batch_size, sequence_length = decoder_input_ids.shape
555
+ if decoder_attention_mask is None:
556
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
557
+
558
+ if decoder_position_ids is None:
559
+ if past_key_values is not None:
560
+ raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
561
+
562
+ decoder_position_ids = jnp.broadcast_to(
563
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
564
+ )
565
+
566
+ # Handle any PRNG if needed
567
+ rngs = {}
568
+ if dropout_rng is not None:
569
+ rngs["dropout"] = dropout_rng
570
+
571
+ inputs = {"params": params or self.params}
572
+
573
+         # If past_key_values are passed, the cache is already initialized, and a private flag `init_cache` has to
+         # be passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
+         # updated by the decoder's attention modules.
576
+ if past_key_values:
577
+ inputs["cache"] = past_key_values
578
+ mutable = ["cache"]
579
+ else:
580
+ mutable = False
581
+
582
+ def _decoder_forward(
583
+ module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
584
+ ):
585
+ projection_module = module._get_projection_module()
586
+ decoder_module = module._get_decoder_module()
587
+
588
+ # optionally project encoder_hidden_states
589
+ if projection_module is not None:
590
+ encoder_hidden_states = projection_module(encoder_hidden_states)
591
+
592
+ return decoder_module(
593
+ decoder_input_ids,
594
+ decoder_attention_mask,
595
+ decoder_position_ids,
596
+ encoder_hidden_states=encoder_hidden_states,
597
+ **kwargs,
598
+ )
599
+
600
+ outputs = self.module.apply(
601
+ inputs,
602
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
603
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
604
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
605
+ encoder_hidden_states=encoder_hidden_states,
606
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
607
+ output_attentions=output_attentions,
608
+ output_hidden_states=output_hidden_states,
609
+ return_dict=return_dict,
610
+ deterministic=not train,
611
+ rngs=rngs,
612
+ mutable=mutable,
613
+ method=_decoder_forward,
614
+ )
615
+
616
+ # add updated cache to model output
617
+ if past_key_values is not None and return_dict:
618
+ outputs, past = outputs
619
+ outputs["past_key_values"] = unfreeze(past["cache"])
620
+ return outputs
621
+ elif past_key_values is not None and not return_dict:
622
+ outputs, past = outputs
623
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
624
+
625
+ return outputs
626
+
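+     # Illustrative sketch (an assumption, not part of the library API): one greedy decoding
+     # step that reuses the cache returned by a previous call to `decode`.
+     def _greedy_decode_step_example(self, encoder_outputs, decoder_input_ids, past_key_values, position):
+         outputs = self.decode(
+             decoder_input_ids[:, -1:],
+             encoder_outputs,
+             past_key_values=past_key_values,
+             decoder_position_ids=jnp.full((decoder_input_ids.shape[0], 1), position, dtype="i4"),
+         )
+         next_token = outputs.logits[:, -1, :].argmax(axis=-1)
+         return next_token, outputs.past_key_values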
627
+ @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING)
628
+ @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
629
+ def __call__(
630
+ self,
631
+ input_ids: jnp.ndarray,
632
+ attention_mask: Optional[jnp.ndarray] = None,
633
+ decoder_input_ids: Optional[jnp.ndarray] = None,
634
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
635
+ position_ids: Optional[jnp.ndarray] = None,
636
+ decoder_position_ids: Optional[jnp.ndarray] = None,
637
+ output_attentions: Optional[bool] = None,
638
+ output_hidden_states: Optional[bool] = None,
639
+ return_dict: Optional[bool] = None,
640
+ train: bool = False,
641
+ params: dict = None,
642
+ dropout_rng: PRNGKey = None,
643
+ ):
644
+ r"""
645
+ Returns:
646
+
647
+ Examples:
648
+
649
+ ```python
650
+ >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer, GPT2Tokenizer
651
+
652
+ >>> # load a fine-tuned bert2gpt2 model
653
+ >>> model = FlaxEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
654
+ >>> # load input & output tokenizer
655
+ >>> tokenizer_input = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
656
+ >>> tokenizer_output = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
657
+
658
+ >>> article = '''Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members
659
+ >>> singing a racist chant. SAE's national chapter suspended the students,
660
+ >>> but University of Oklahoma President David Boren took it a step further,
661
+ >>> saying the university's affiliation with the fraternity is permanently done.'''
662
+
663
+ >>> input_ids = tokenizer_input(article, add_special_tokens=True, return_tensors="np").input_ids
664
+
665
+ >>> # use GPT2's eos_token as the pad as well as eos token
666
+ >>> model.config.eos_token_id = model.config.decoder.eos_token_id
667
+ >>> model.config.pad_token_id = model.config.eos_token_id
668
+
669
+ >>> sequences = model.generate(input_ids, num_beams=4, max_length=12).sequences
670
+
671
+ >>> summary = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)[0]
672
+ >>> assert summary == "SAS Alpha Epsilon suspended Sigma Alpha Epsilon members"
673
+ ```
674
+ """
675
+
676
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
677
+ output_hidden_states = (
678
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
679
+ )
680
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
681
+
682
+ # prepare encoder inputs
683
+ if attention_mask is None:
684
+ attention_mask = jnp.ones_like(input_ids)
685
+ if position_ids is None:
686
+ batch_size, sequence_length = input_ids.shape
687
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
688
+
689
+ # prepare decoder inputs
690
+ if decoder_input_ids is None:
691
+ raise ValueError(
692
+ "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must"
693
+ " be specified as an input argument."
694
+ )
695
+ if decoder_attention_mask is None:
696
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
697
+ if decoder_position_ids is None:
698
+ batch_size, sequence_length = decoder_input_ids.shape
699
+ decoder_position_ids = jnp.broadcast_to(
700
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
701
+ )
702
+
703
+ # Handle any PRNG if needed
704
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
705
+
706
+ return self.module.apply(
707
+ {"params": params or self.params},
708
+ input_ids=jnp.array(input_ids, dtype="i4"),
709
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
710
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
711
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
712
+ position_ids=jnp.array(position_ids, dtype="i4"),
713
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
714
+ output_attentions=output_attentions,
715
+ output_hidden_states=output_hidden_states,
716
+ return_dict=return_dict,
717
+ deterministic=not train,
718
+ rngs=rngs,
719
+ )
720
+
721
+ def prepare_inputs_for_generation(
722
+ self,
723
+ decoder_input_ids,
724
+ max_length,
725
+ attention_mask: Optional[jax.Array] = None,
726
+ decoder_attention_mask: Optional[jax.Array] = None,
727
+ encoder_outputs=None,
728
+ **kwargs,
729
+ ):
730
+ # initializing the cache
731
+ batch_size, seq_length = decoder_input_ids.shape
732
+
733
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
734
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
735
+         # But since the decoder uses a causal mask, those positions are masked anyway.
+         # Thus, we can create a single static attention_mask here, which is more efficient for compilation.
737
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
738
+ if decoder_attention_mask is not None:
739
+ decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
740
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
741
+ else:
742
+ decoder_position_ids = jnp.broadcast_to(
743
+ jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
744
+ )
745
+
746
+ return {
747
+ "past_key_values": past_key_values,
748
+ "encoder_outputs": encoder_outputs,
749
+ "encoder_attention_mask": attention_mask,
750
+ "decoder_attention_mask": extended_attention_mask,
751
+ "decoder_position_ids": decoder_position_ids,
752
+ }
753
+
754
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
755
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
756
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
757
+ return model_kwargs
758
+
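+     # Illustrative sketch (an assumption, not part of the library API): the static decoder
+     # attention mask built in `prepare_inputs_for_generation` above is a max-length buffer of
+     # ones with the real mask written into its first columns via `lax.dynamic_update_slice`.
+     def _extended_decoder_mask_example(self, decoder_attention_mask, max_length):
+         batch_size = decoder_attention_mask.shape[0]
+         extended = jnp.ones((batch_size, max_length), dtype="i4")
+         return lax.dynamic_update_slice(extended, decoder_attention_mask.astype("i4"), (0, 0))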
759
+ @classmethod
760
+ def from_encoder_decoder_pretrained(
761
+ cls,
762
+ encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
763
+ decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
764
+ *model_args,
765
+ **kwargs,
766
+ ) -> FlaxPreTrainedModel:
767
+ r"""
768
+ Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
769
+ checkpoints.
770
+
771
+ Params:
772
+ encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
773
+                 Information necessary to initialize the encoder. Can be either:
774
+
775
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
776
+ - A path to a *directory* containing model weights saved using
777
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
778
+
779
+ decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
780
+                 Information necessary to initialize the decoder. Can be either:
781
+
782
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
783
+ - A path to a *directory* containing model weights saved using
784
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
785
+
786
+ model_args (remaining positional arguments, *optional*):
787
+                 All remaining positional arguments will be passed to the underlying model's `__init__` method.
788
+
789
+ kwargs (remaining dictionary of keyword arguments, *optional*):
790
+                 Can be used to update the configuration object (after it has been loaded) and initialize the model
+                 (e.g., `output_attentions=True`).
791
+ 
792
+
793
+ - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
794
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
795
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
796
+
797
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
798
+
799
+ Example:
800
+
801
+ ```python
802
+ >>> from transformers import FlaxEncoderDecoderModel
803
+
804
+ >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
805
+ >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
806
+ >>> # saving model after fine-tuning
807
+ >>> model.save_pretrained("./bert2gpt2")
808
+ >>> # load fine-tuned model
809
+ >>> model = FlaxEncoderDecoderModel.from_pretrained("./bert2gpt2")
810
+ ```"""
811
+
812
+ kwargs_encoder = {
813
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
814
+ }
815
+
816
+ kwargs_decoder = {
817
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
818
+ }
819
+
820
+ # remove encoder, decoder kwargs from kwargs
821
+ for key in kwargs_encoder.keys():
822
+ del kwargs["encoder_" + key]
823
+ for key in kwargs_decoder.keys():
824
+ del kwargs["decoder_" + key]
825
+
826
+ # Load and initialize the encoder and decoder
827
+ # The distinction between encoder and decoder at the model level is made
828
+ # by the value of the flag `is_decoder` that we need to set correctly.
829
+ encoder = kwargs_encoder.pop("model", None)
830
+ if encoder is None:
831
+ if encoder_pretrained_model_name_or_path is None:
832
+ raise ValueError(
833
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
834
+ "to be defined."
835
+ )
836
+
837
+ if "config" not in kwargs_encoder:
838
+ encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
839
+ encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
840
+ )
841
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
842
+ logger.info(
843
+ f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
844
+ "from a decoder model. Cross-attention and casual mask are disabled."
845
+ )
846
+ encoder_config.is_decoder = False
847
+ encoder_config.add_cross_attention = False
848
+
849
+ kwargs_encoder["config"] = encoder_config
850
+
851
+ encoder = FlaxAutoModel.from_pretrained(
852
+ encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
853
+ )
854
+
855
+ decoder = kwargs_decoder.pop("model", None)
856
+ if decoder is None:
857
+ if decoder_pretrained_model_name_or_path is None:
858
+ raise ValueError(
859
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
860
+ "to be defined."
861
+ )
862
+
863
+ if "config" not in kwargs_decoder:
864
+ decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
865
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
866
+ )
867
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
868
+ logger.info(
869
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
870
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
871
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
872
+ )
873
+ decoder_config.is_decoder = True
874
+ decoder_config.add_cross_attention = True
875
+
876
+ kwargs_decoder["config"] = decoder_config
877
+
878
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
879
+ logger.warning(
880
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
881
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
882
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
883
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
884
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
885
+ )
886
+
887
+ decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
888
+
889
+ # instantiate config with corresponding kwargs
890
+ dtype = kwargs.pop("dtype", jnp.float32)
891
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
892
+
893
+ # init model
894
+ model = cls(config, dtype=dtype)
895
+ model.params["encoder"] = encoder.params
896
+ model.params["decoder"] = decoder.params
897
+
898
+ return model
899
+
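+ # Illustrative usage sketch (an assumption, not upstream code): kwargs prefixed with
+ # `encoder_` / `decoder_` are stripped of their prefix and routed to the corresponding
+ # `from_pretrained` call, as described in the docstring of the method above.
+ def _from_encoder_decoder_with_overrides_example():
+     return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
+         "google-bert/bert-base-cased",
+         "openai-community/gpt2",
+         encoder_hidden_dropout_prob=0.2,  # forwarded to the encoder config as `hidden_dropout_prob`
+         decoder_use_cache=True,  # forwarded to the decoder config as `use_cache`
+     )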
900
+
901
+ __all__ = ["FlaxEncoderDecoderModel"]
.venv/lib/python3.11/site-packages/transformers/models/mpt/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_mpt import *
22
+ from .modeling_mpt import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
.venv/lib/python3.11/site-packages/transformers/models/mpt/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (757 Bytes). View file
 
.venv/lib/python3.11/site-packages/transformers/models/mpt/__pycache__/configuration_mpt.cpython-311.pyc ADDED
Binary file (10.9 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/mpt/__pycache__/modeling_mpt.cpython-311.pyc ADDED
Binary file (45.3 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/mpt/configuration_mpt.py ADDED
@@ -0,0 +1,233 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 HuggingFace Inc. team and MosaicML NLP team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Mpt configuration"""
16
+
17
+ from typing import TYPE_CHECKING, Optional, Union
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ pass
22
+
23
+ from ...configuration_utils import PretrainedConfig
24
+ from ...utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class MptAttentionConfig(PretrainedConfig):
31
+ """
32
+ This is the configuration class to store the configuration of a [`MptAttention`] class. It is used to instantiate
33
+ attention layers according to the specified arguments, defining the layers architecture. Instantiating a
34
+ configuration with the defaults will yield a similar configuration to that of the MPT
35
+ [mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b) architecture. Most of the arguments are kept for backward
36
+ compatibility with previous MPT models that are hosted on the Hub (previously with `trust_remote_code=True`).
37
+
38
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
+ documentation from [`PretrainedConfig`] for more information.
40
+
41
+ Args:
42
+ attn_type (`str`, *optional*, defaults to `"multihead_attention"`):
43
+ type of attention to use. Options: `"multihead_attention"`, `"multiquery_attention"`.
44
+ attn_pdrop (`float`, *optional*, defaults to `0.0`):
45
+ The dropout probability for the attention layers.
46
+ attn_impl (`str`, *optional*, defaults to `"torch"`):
47
+ The attention implementation to use. One of `"torch"`, `"flash"`, or `"triton"`.
48
+ clip_qkv (`float`, *optional*):
49
+ If not `None`, clip the queries, keys, and values in the attention layer to this value.
50
+ softmax_scale (`float`, *optional*):
51
+ If not `None`, scale the softmax in the attention layer by this value. If `None`, will default to
52
+ `1/sqrt(hidden_size)`.
53
+ prefix_lm (`bool`, *optional*, defaults to `False`):
54
+ Whether the model should operate as a Prefix LM. This requires passing an extra `prefix_mask` argument
55
+ which indicates which tokens belong to the prefix. Tokens in the prefix can attend to one another
56
+ bi-directionally. Tokens outside the prefix use causal attention.
57
+ qk_ln (`bool`, *optional*, defaults to `False`):
58
+ Whether to apply layer normalization to the queries and keys in the attention layer.
59
+ attn_uses_sequence_id (`bool`, *optional*, defaults to `False`):
60
+ Whether to restrict attention to tokens that have the same token_type_ids. When the model is in `train`
61
+ mode, this requires passing an extra *token_type_ids* argument which indicates which sub-sequence each
62
+ token belongs to. Defaults to `False` meaning any provided *token_type_ids* will be ignored.
63
+ alibi (`bool`, *optional*, defaults to `True`):
64
+ Whether or not to use the alibi bias instead of positional embedding.
65
+ alibi_bias_max (`int`, *optional*, defaults to 8):
66
+ The maximum value of the alibi bias.
67
+ """
68
+
69
+ base_config_key = "attn_config"
70
+
71
+ def __init__(
72
+ self,
73
+ attn_type="multihead_attention",
74
+ attn_pdrop=0,
75
+ attn_impl="torch",
76
+ clip_qkv=None,
77
+ softmax_scale=None,
78
+ prefix_lm=False,
79
+ qk_ln=False,
80
+ attn_uses_sequence_id=False,
81
+ alibi=True,
82
+ alibi_bias_max=8,
83
+ **kwargs,
84
+ ):
85
+ super().__init__()
86
+ self.attn_type = attn_type
87
+ self.attn_pdrop = attn_pdrop
88
+ self.attn_impl = attn_impl
89
+ self.clip_qkv = clip_qkv
90
+ self.softmax_scale = softmax_scale
91
+ self.prefix_lm = prefix_lm
92
+ self.attn_uses_sequence_id = attn_uses_sequence_id
93
+ self.alibi = alibi
94
+ self.qk_ln = qk_ln
95
+ self.alibi_bias_max = alibi_bias_max
96
+
97
+ if attn_type not in ["multihead_attention", "multiquery_attention"]:
98
+ raise ValueError(
99
+ f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. Received: {attn_type}"
100
+ )
101
+
102
+
103
+ class MptConfig(PretrainedConfig):
104
+ """
105
+ This is the configuration class to store the configuration of a [`MptModel`]. It is used to instantiate a Mpt model
106
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
107
+ defaults will yield a similar configuration to the Mpt-7b architecture
108
+ [mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b).
109
+
110
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
111
+ documentation from [`PretrainedConfig`] for more information.
112
+
113
+
114
+ Args:
115
+ d_model (`int`, *optional*, defaults to 2048):
116
+ Dimensionality of the embeddings and hidden states.
117
+ n_heads (`int`, *optional*, defaults to 16):
118
+ Number of attention heads for each attention layer in the Transformer encoder.
119
+ n_layers (`int`, *optional*, defaults to 24):
120
+ Number of hidden layers in the Transformer encoder.
121
+ expansion_ratio (`int`, *optional*, defaults to 4):
122
+ The ratio of the up/down scale in the MLP.
123
+ max_seq_len (`int`, *optional*, defaults to 2048):
124
+ The maximum sequence length of the model.
125
+ vocab_size (`int`, *optional*, defaults to 50368):
126
+ Vocabulary size of the Mpt model. Defines the maximum number of different tokens that can be represented by
127
+ the `inputs_ids` passed when calling [`MptModel`]. Check [this
128
+ discussion](https://huggingface.co/bigscience/mpt/discussions/120#633d28389addb8530b406c2a) on how the
129
+ `vocab_size` has been defined.
130
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
131
+ The dropout probability applied to the attention output before combining with residual.
132
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
133
+ The epsilon to use in the layer normalization layers.
134
+ emb_pdrop (`float`, *optional*, defaults to 0.0):
135
+ The dropout probability for the embedding layer.
136
+ learned_pos_emb (`bool`, *optional*, defaults to `True`):
137
+ Whether to use learned positional embeddings.
138
+ attn_config (`dict`, *optional*):
139
+ A dictionary used to configure the model's attention module.
140
+ init_device (`str`, *optional*, defaults to `"cpu"`):
141
+ The device to use for parameter initialization. Defined for backward compatibility
142
+ logit_scale (`float`, *optional*):
143
+ If not None, scale the logits by this value.
144
+ no_bias (`bool`, *optional*, defaults to `True`):
145
+ Whether to use bias in all linear layers.
146
+ verbose (`int`, *optional*, defaults to 0):
147
+ The verbosity level to use for logging. Used in the previous versions of MPT models for logging. This
148
+ argument is deprecated.
149
+ embedding_fraction (`float`, *optional*, defaults to 1.0):
150
+ The fraction to scale the gradients of the embedding layer by.
151
+ norm_type (`str`, *optional*, defaults to `"low_precision_layernorm"`):
152
+ Type of layer norm to use. All MPT models uses the same layer norm implementation. Defined for backward
153
+ compatibility.
154
+ use_cache (`bool`, *optional*, defaults to `False`):
155
+ Whether or not the model should return the last key/values attentions (not used by all models).
156
+ initializer_range (`float`, *optional*, defaults to 0.02):
157
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
158
+
159
+ Example:
160
+
161
+ ```python
162
+ >>> from transformers import MptConfig, MptModel
163
+
164
+ >>> # Initializing a Mpt configuration
165
+ >>> configuration = MptConfig()
166
+
167
+ >>> # Initializing a model (with random weights) from the configuration
168
+ >>> model = MptModel(configuration)
169
+
170
+ >>> # Accessing the model configuration
171
+ >>> configuration = model.config
172
+ ```
173
+ """
174
+
175
+ model_type = "mpt"
176
+ sub_configs = {"attn_config": MptAttentionConfig}
177
+ attribute_map = {
178
+ "num_attention_heads": "n_heads",
179
+ "hidden_size": "d_model",
180
+ "num_hidden_layers": "n_layers",
181
+ }
182
+
183
+ def __init__(
184
+ self,
185
+ d_model: int = 2048,
186
+ n_heads: int = 16,
187
+ n_layers: int = 24,
188
+ expansion_ratio: int = 4,
189
+ max_seq_len: int = 2048,
190
+ vocab_size: int = 50368,
191
+ resid_pdrop: float = 0.0,
192
+ layer_norm_epsilon: float = 1e-5,
193
+ emb_pdrop: float = 0.0,
194
+ learned_pos_emb: bool = True,
195
+ attn_config: MptAttentionConfig = None,
196
+ init_device: str = "cpu",
197
+ logit_scale: Optional[Union[float, str]] = None,
198
+ no_bias: bool = True,
199
+ verbose: int = 0,
200
+ embedding_fraction: float = 1.0,
201
+ norm_type: str = "low_precision_layernorm",
202
+ use_cache: bool = False,
203
+ initializer_range=0.02,
204
+ **kwargs,
205
+ ):
206
+ if attn_config is None:
207
+ self.attn_config = MptAttentionConfig()
208
+ elif isinstance(attn_config, dict):
209
+ self.attn_config = MptAttentionConfig(**attn_config)
210
+ else:
211
+ self.attn_config = attn_config
212
+ self.d_model = d_model
213
+ self.n_heads = n_heads
214
+ self.n_layers = n_layers
215
+ self.expansion_ratio = expansion_ratio
216
+ self.max_seq_len = max_seq_len
217
+ self.vocab_size = vocab_size
218
+ self.resid_pdrop = resid_pdrop
219
+ self.emb_pdrop = emb_pdrop
220
+ self.learned_pos_emb = learned_pos_emb
221
+ self.init_device = init_device
222
+ self.logit_scale = logit_scale
223
+ self.no_bias = no_bias
224
+ self.verbose = verbose
225
+ self.embedding_fraction = embedding_fraction
226
+ self.norm_type = norm_type
227
+ self.layer_norm_epsilon = layer_norm_epsilon
228
+ self.use_cache = use_cache
229
+ self.initializer_range = initializer_range
230
+ super().__init__(**kwargs)
231
+
232
+
233
+ __all__ = ["MptConfig"]
.venv/lib/python3.11/site-packages/transformers/models/mpt/modeling_mpt.py ADDED
@@ -0,0 +1,917 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 HuggingFace Inc. team and MosaicML NLP team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch MPT model."""
16
+
17
+ import math
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
24
+ from torch.nn import functional as F
25
+
26
+ from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
27
+ from ...generation import GenerationMixin
28
+ from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
29
+ from ...modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ QuestionAnsweringModelOutput,
33
+ SequenceClassifierOutputWithPast,
34
+ TokenClassifierOutput,
35
+ )
36
+ from ...modeling_utils import PreTrainedModel
37
+ from ...utils import logging
38
+ from .configuration_mpt import MptConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ _CHECKPOINT_FOR_DOC = "mosaicml/mpt-7b"
44
+ _CONFIG_FOR_DOC = "MptConfig"
45
+
46
+
47
+ def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device=None):
48
+ r"""
49
+ Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
50
+ relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
51
+ the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
52
+ https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
53
+ """
54
+ alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length)
55
+ num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads))
56
+
57
+ base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.int64, device=device).float()
58
+ base = base * (alibi_bias_max / num_heads_power_of_2)
59
+
60
+ slopes = 1.0 / torch.pow(2, base)
61
+ slopes = slopes.view(1, num_heads_power_of_2, 1, 1)
62
+
63
+ if num_heads_power_of_2 != num_heads:
64
+ slopes = torch.concat([slopes[:, 1::2, ...], slopes[:, ::2, ...]], dim=1)[:, :num_heads, ...]
65
+
66
+ alibi = alibi * slopes
67
+ return alibi.squeeze(0)
68
+
69
+
70
+ class MptAttention(nn.Module):
71
+ """Multi-head self attention.
72
+ Using torch or triton attention implemetation enables user to also use additive bias.
73
+ """
74
+
75
+ def __init__(self, config: MptConfig):
76
+ super().__init__()
77
+ self.hidden_size = config.hidden_size
78
+ self.n_heads = config.n_heads
79
+ self.max_seq_length = config.max_seq_len
80
+ self.head_dim = self.hidden_size // self.n_heads
81
+ self.softmax_scale = config.attn_config.softmax_scale
82
+ if self.softmax_scale is None:
83
+ self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads)
84
+
85
+ self.attn_dropout_p = config.attn_config.attn_pdrop
86
+ self.clip_qkv = config.attn_config.clip_qkv
87
+ self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
88
+ self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
89
+
90
+ def forward(
91
+ self,
92
+ hidden_states: torch.Tensor,
93
+ position_bias: torch.Tensor,
94
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
95
+ attention_mask: Optional[torch.Tensor] = None,
96
+ ):
97
+ batch_size, seq_length = hidden_states.shape[:2]
98
+
99
+ mixed_qkv = self.Wqkv(hidden_states)
100
+ if self.clip_qkv:
101
+ mixed_qkv = mixed_qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
102
+
103
+ query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
104
+ query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
105
+ key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
106
+ value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
107
+
108
+ if past_key_value is not None:
109
+ if len(past_key_value) != 0:
110
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
111
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
112
+ past_key_value = (key_states, value_states)
113
+ else:
114
+ past_key_value = (key_states, value_states)
115
+
116
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
117
+
118
+ query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
119
+
120
+ if position_bias is not None:
121
+ if len(position_bias.shape) != 3:
122
+ raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}")
123
+ key_length = key_states.shape[-2]
124
+
125
+ position_bias_query_index = max(0, position_bias.size(1) - query_length)
126
+ position_bias_key_index = max(0, position_bias.size(2) - key_length)
127
+
128
+ position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]
129
+
130
+ attention_scores = attention_scores + position_bias
131
+
132
+ if attention_mask is not None:
133
+ attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min)
134
+
135
+ # (batch_size, n_heads, seq_length, key_length)
136
+ attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).to(value_states.dtype)
137
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training)
138
+
139
+ context_states = torch.matmul(attn_weights, value_states)
140
+ context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
141
+ attn_output = self.out_proj(context_states)
142
+
143
+ return attn_output, attn_weights, past_key_value
144
+
145
+
146
+ class MptMLP(nn.Module):
147
+ def __init__(self, config: MptConfig):
148
+ super().__init__()
149
+ hidden_size = config.hidden_size
150
+
151
+ self.up_proj = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
152
+ self.act = nn.GELU(approximate="none")
153
+ self.down_proj = nn.Linear(4 * hidden_size, hidden_size, bias=False)
154
+ self.hidden_dropout = config.attn_config.attn_pdrop
155
+
156
+ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
157
+ hidden_states = self.act(self.up_proj(hidden_states))
158
+
159
+ intermediate_output = self.down_proj(hidden_states)
160
+
161
+ output = F.dropout(intermediate_output, p=self.hidden_dropout, training=self.training)
162
+ output = output + residual
163
+
164
+ return output
165
+
166
+
167
+ class MptBlock(nn.Module):
168
+ def __init__(self, config: MptConfig):
169
+ super().__init__()
170
+ hidden_size = config.hidden_size
171
+
172
+ self.norm_1 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
173
+ # backward compatibility with weights on the Hub
174
+ self.norm_1.bias = None
175
+
176
+ self.num_heads = config.n_heads
177
+ self.attn = MptAttention(config)
178
+
179
+ self.norm_2 = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
180
+ # backward compatibility with weights on the Hub
181
+ self.norm_2.bias = None
182
+
183
+ self.ffn = MptMLP(config)
184
+
185
+ self.dropout_rate = config.attn_config.attn_pdrop
186
+ self.resid_attn_dropout = nn.Dropout(self.dropout_rate)
187
+
188
+ def forward(
189
+ self,
190
+ hidden_states: torch.Tensor,
191
+ position_bias: torch.Tensor,
192
+ attention_mask: torch.Tensor,
193
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
194
+ use_cache: bool = False,
195
+ output_attentions: bool = False,
196
+ ):
197
+ # hidden_states: [batch_size, seq_length, hidden_size]
198
+ # Layer norm at the beginning of the transformer layer.
199
+ layernorm_output = self.norm_1(hidden_states)
200
+
201
+ residual = hidden_states
202
+
203
+ # Self attention.
204
+ attn_outputs, attn_weights, past_key_value = self.attn(
205
+ layernorm_output,
206
+ position_bias=position_bias,
207
+ attention_mask=attention_mask,
208
+ past_key_value=layer_past,
209
+ )
210
+
211
+ hidden_states = self.resid_attn_dropout(attn_outputs) + residual
212
+
213
+ layernorm_output = self.norm_2(hidden_states)
214
+
215
+ # Get residual
216
+ residual = hidden_states
217
+
218
+ # MLP.
219
+ output = self.ffn(layernorm_output, residual)
220
+ outputs = (output,)
221
+
222
+ if use_cache:
223
+ outputs += (past_key_value,)
224
+
225
+ if output_attentions:
226
+ outputs += (attn_weights,)
227
+
228
+ return outputs # hidden_states, present, attentions
229
+
230
+
231
+ class MptPreTrainedModel(PreTrainedModel):
232
+ config_class = MptConfig
233
+ base_model_prefix = "transformer"
234
+ supports_gradient_checkpointing = True
235
+ _no_split_modules = ["MptBlock"]
236
+ _keys_to_ignore_on_load_missing = [r"lm_head.*."]
237
+
238
+ def __init__(self, *inputs, **kwargs):
239
+ super().__init__(*inputs, **kwargs)
240
+
241
+ def _init_weights(self, module: nn.Module):
242
+ """Initialize the weights."""
243
+ if isinstance(module, nn.Linear):
244
+ # Slightly different from the TF version which uses truncated_normal for initialization
245
+ # cf https://github.com/pytorch/pytorch/pull/5617
246
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
247
+ if module.bias is not None:
248
+ module.bias.data.zero_()
249
+ elif isinstance(module, nn.Embedding):
250
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
251
+ if module.padding_idx is not None:
252
+ module.weight.data[module.padding_idx].zero_()
253
+ elif isinstance(module, LayerNorm):
254
+ if module.bias is not None:
255
+ module.bias.data.zero_()
256
+ module.weight.data.fill_(1.0)
257
+
258
+ @staticmethod
259
+ def _convert_to_mpt_cache(
260
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
261
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
262
+ """
263
+ Converts the cache to the format expected by Mpt, i.e. to tuple(tuple([batch_size * num_heads, ...]))
264
+ """
265
+ batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
266
+ batch_size_times_num_heads = batch_size * num_heads
267
+ # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
268
+ # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
269
+ return tuple(
270
+ (
271
+ layer_past[0].reshape(batch_size_times_num_heads, head_dim, seq_length),
272
+ layer_past[1].reshape(batch_size_times_num_heads, seq_length, head_dim),
273
+ )
274
+ for layer_past in past_key_value
275
+ )
276
+
277
+
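`_convert_to_mpt_cache` simply folds the head dimension into the batch dimension; a quick shape-only illustration with toy sizes:

```python
import torch

batch, heads, head_dim, kv_len = 2, 4, 8, 5
key = torch.randn(batch, heads, head_dim, kv_len)    # [batch, num_heads, head_dim, kv_length]
value = torch.randn(batch, heads, kv_len, head_dim)  # [batch, num_heads, kv_length, head_dim]

key_mpt = key.reshape(batch * heads, head_dim, kv_len)
value_mpt = value.reshape(batch * heads, kv_len, head_dim)
print(key_mpt.shape, value_mpt.shape)  # torch.Size([8, 8, 5]) torch.Size([8, 5, 8])
```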
278
+ MPT_START_DOCSTRING = r"""
279
+
280
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
281
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, etc.)
282
+
283
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
284
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
285
+ and behavior.
286
+
287
+ Parameters:
288
+ config ([`MptConfig`]): Model configuration class with all the parameters of the model.
289
+ Initializing with a config file does not load the weights associated with the model, only the
290
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
291
+ """
292
+
293
+ MPT_INPUTS_DOCSTRING = r"""
294
+ Args:
295
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
296
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
297
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
298
+
299
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
300
+ `input_ids`.
301
+
302
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
303
+ [`PreTrainedTokenizer.__call__`] for details.
304
+
305
+ [What are input IDs?](../glossary#input-ids)
306
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
307
+ Contains precomputed hidden-states (keys and values in the attention blocks) as computed by the model (see
308
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
309
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
310
+
311
+ Each element of `past_key_values` is a tuple (past_key, past_value):
312
+ - past_key: [batch_size * num_heads, head_dim, kv_length]
313
+ - past_value: [batch_size * num_heads, kv_length, head_dim]
314
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
315
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
316
+
317
+ - 1 for tokens that are **not masked**,
318
+ - 0 for tokens that are **masked**.
319
+
320
+ [What are attention masks?](../glossary#attention-mask)
321
+
322
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
323
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
324
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
325
+ model's internal embedding lookup matrix.
326
+
327
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
328
+ `past_key_values`).
329
+ use_cache (`bool`, *optional*):
330
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
331
+ `past_key_values`).
332
+ output_attentions (`bool`, *optional*):
333
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
334
+ tensors for more detail.
335
+ output_hidden_states (`bool`, *optional*):
336
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
337
+ more detail.
338
+ return_dict (`bool`, *optional*):
339
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
340
+ """
341
+
342
+
343
+ @add_start_docstrings(
344
+ "The bare Mpt Model transformer outputting raw hidden-states without any specific head on top.",
345
+ MPT_START_DOCSTRING,
346
+ )
347
+ class MptModel(MptPreTrainedModel):
348
+ def __init__(self, config: MptConfig):
349
+ super().__init__(config)
350
+
351
+ self.hidden_size = config.hidden_size
352
+ self.num_heads = config.n_heads
353
+
354
+ # Embedding + LN Embedding
355
+ self.wte = nn.Embedding(config.vocab_size, self.hidden_size)
356
+
357
+ # Transformer blocks
358
+ self.blocks = nn.ModuleList([MptBlock(config) for _ in range(config.n_layers)])
359
+
360
+ # Final Layer Norm
361
+ self.norm_f = LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
362
+ # backward compatibility with weights on the Hub
363
+ self.norm_f.bias = None
364
+
365
+ self.gradient_checkpointing = False
366
+
367
+ # Initialize weights and apply final processing
368
+ self.post_init()
369
+
370
+ def get_input_embeddings(self):
371
+ return self.wte
372
+
373
+ def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None):
374
+ return build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max, device)
375
+
376
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
377
+ self.wte = new_embeddings
378
+
379
+ @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
380
+ @add_code_sample_docstrings(
381
+ checkpoint=_CHECKPOINT_FOR_DOC,
382
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
383
+ config_class=_CONFIG_FOR_DOC,
384
+ )
385
+ def forward(
386
+ self,
387
+ input_ids: Optional[torch.LongTensor] = None,
388
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
389
+ attention_mask: Optional[torch.Tensor] = None,
390
+ inputs_embeds: Optional[torch.LongTensor] = None,
391
+ use_cache: Optional[bool] = None,
392
+ output_attentions: Optional[bool] = None,
393
+ output_hidden_states: Optional[bool] = None,
394
+ return_dict: Optional[bool] = None,
395
+ **kwargs, # NOOP kwargs, for now
396
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
397
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
398
+ output_hidden_states = (
399
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
400
+ )
401
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
402
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
403
+
404
+ if input_ids is not None and inputs_embeds is not None:
405
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
406
+ elif input_ids is not None:
407
+ batch_size, seq_length = input_ids.shape
408
+ elif inputs_embeds is not None:
409
+ batch_size, seq_length, _ = inputs_embeds.shape
410
+ else:
411
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
412
+
413
+ if past_key_values is None:
414
+ past_key_values = tuple([None] * len(self.blocks))
415
+
416
+ if inputs_embeds is None:
417
+ inputs_embeds = self.wte(input_ids)
418
+
419
+ hidden_states = inputs_embeds
420
+
421
+ presents = () if use_cache else None
422
+ all_self_attentions = () if output_attentions else None
423
+ all_hidden_states = () if output_hidden_states else None
424
+
425
+ if self.gradient_checkpointing and self.training:
426
+ if use_cache:
427
+ logger.warning_once(
428
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
429
+ )
430
+ use_cache = False
431
+
432
+ # Compute alibi tensor: check build_alibi_tensor documentation
433
+ seq_length_with_past = seq_length
434
+ past_key_values_length = 0
435
+ if past_key_values[0] is not None:
436
+ past_key_values_length = past_key_values[0][0].shape[2]
437
+ seq_length_with_past = seq_length_with_past + past_key_values_length
438
+ if attention_mask is None:
439
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
440
+ else:
441
+ attention_mask = attention_mask.to(hidden_states.device)
442
+
443
+ alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device)
444
+
445
+ causal_mask = _prepare_4d_causal_attention_mask(
446
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
447
+ )
448
+ causal_mask = causal_mask.bool()
449
+
450
+ for block, layer_past in zip(self.blocks, past_key_values):
451
+ if output_hidden_states:
452
+ all_hidden_states = all_hidden_states + (hidden_states,)
453
+
454
+ if self.gradient_checkpointing and self.training:
455
+ outputs = self._gradient_checkpointing_func(
456
+ block.__call__,
457
+ hidden_states,
458
+ alibi,
459
+ causal_mask,
460
+ layer_past,
461
+ use_cache,
462
+ output_attentions,
463
+ )
464
+ else:
465
+ outputs = block(
466
+ hidden_states,
467
+ layer_past=layer_past,
468
+ attention_mask=causal_mask,
469
+ use_cache=use_cache,
470
+ output_attentions=output_attentions,
471
+ position_bias=alibi,
472
+ )
473
+
474
+ hidden_states = outputs[0]
475
+ if use_cache is True:
476
+ presents = presents + (outputs[1],)
477
+
478
+ if output_attentions:
479
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
480
+
481
+ # Add last hidden state
482
+ hidden_states = self.norm_f(hidden_states)
483
+
484
+ if output_hidden_states:
485
+ all_hidden_states = all_hidden_states + (hidden_states,)
486
+
487
+ if not return_dict:
488
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
489
+
490
+ return BaseModelOutputWithPastAndCrossAttentions(
491
+ last_hidden_state=hidden_states,
492
+ past_key_values=presents,
493
+ hidden_states=all_hidden_states,
494
+ attentions=all_self_attentions,
495
+ )
496
+
497
+
498
+ @add_start_docstrings(
499
+ """
500
+ The MPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
501
+ embeddings).
502
+ """,
503
+ MPT_START_DOCSTRING,
504
+ )
505
+ class MptForCausalLM(MptPreTrainedModel, GenerationMixin):
506
+ _tied_weights_keys = ["lm_head.weight"]
507
+
508
+ def __init__(self, config: MptConfig):
509
+ super().__init__(config)
510
+ self.transformer = MptModel(config)
511
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
512
+
513
+ # Initialize weights and apply final processing
514
+ self.post_init()
515
+
516
+ def get_output_embeddings(self):
517
+ return self.lm_head
518
+
519
+ def set_output_embeddings(self, new_embeddings: torch.Tensor):
520
+ self.lm_head = new_embeddings
521
+
522
+ @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
523
+ @add_code_sample_docstrings(
524
+ checkpoint=_CHECKPOINT_FOR_DOC,
525
+ output_type=CausalLMOutputWithCrossAttentions,
526
+ config_class=_CONFIG_FOR_DOC,
527
+ )
528
+ def forward(
529
+ self,
530
+ input_ids: Optional[torch.LongTensor] = None,
531
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
532
+ attention_mask: Optional[torch.Tensor] = None,
533
+ inputs_embeds: Optional[torch.Tensor] = None,
534
+ labels: Optional[torch.Tensor] = None,
535
+ use_cache: Optional[bool] = None,
536
+ output_attentions: Optional[bool] = None,
537
+ output_hidden_states: Optional[bool] = None,
538
+ return_dict: Optional[bool] = None,
539
+ **kwargs,
540
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
541
+ r"""
542
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
543
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
544
+ `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
545
+ are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
546
+ """
547
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
548
+
549
+ transformer_outputs = self.transformer(
550
+ input_ids,
551
+ past_key_values=past_key_values,
552
+ attention_mask=attention_mask,
553
+ inputs_embeds=inputs_embeds,
554
+ use_cache=use_cache,
555
+ output_attentions=output_attentions,
556
+ output_hidden_states=output_hidden_states,
557
+ return_dict=return_dict,
558
+ )
559
+ hidden_states = transformer_outputs[0]
560
+
561
+ lm_logits = self.lm_head(hidden_states)
562
+
563
+ loss = None
564
+ if labels is not None:
565
+ # move labels to correct device to enable model parallelism
566
+ labels = labels.to(lm_logits.device)
567
+ # Flatten the tokens
568
+ loss = self.loss_function(
569
+ lm_logits,
570
+ labels,
571
+ vocab_size=self.config.vocab_size,
572
+ **kwargs,
573
+ )
574
+
575
+ if not return_dict:
576
+ output = (lm_logits,) + transformer_outputs[1:]
577
+ return ((loss,) + output) if loss is not None else output
578
+
579
+ return CausalLMOutputWithCrossAttentions(
580
+ loss=loss,
581
+ logits=lm_logits,
582
+ past_key_values=transformer_outputs.past_key_values,
583
+ hidden_states=transformer_outputs.hidden_states,
584
+ attentions=transformer_outputs.attentions,
585
+ )
586
+
587
+ def _reorder_cache(
588
+ self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
589
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
590
+ """
591
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
592
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
593
+ beam_idx at every generation step.
594
+
595
+ Output shares the same memory storage as `past`.
596
+ """
597
+ # Get a copy of `beam_idx` on all the devices where we need those indices.
598
+ device_to_beam_idx = {
599
+ past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
600
+ }
601
+ reordered_past = tuple(
602
+ (
603
+ layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
604
+ layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
605
+ )
606
+ for layer_past in past
607
+ )
608
+ return reordered_past
609
+
610
+
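`_reorder_cache` above boils down to an `index_select` along dim 0 with the surviving beam indices; a toy illustration:

```python
import torch

layer_past_key = torch.arange(4 * 6, dtype=torch.float32).reshape(4, 6)  # 4 cache rows
beam_idx = torch.tensor([2, 2, 0, 1])  # beams kept at this generation step
reordered = layer_past_key.index_select(0, beam_idx)
print(torch.equal(reordered[0], layer_past_key[2]))  # True
```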
611
+ @add_start_docstrings(
612
+ """
613
+ The MPT Model transformer with a sequence classification head on top (linear layer).
614
+
615
+ [`MptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
616
+ (e.g. GPT-1) do.
617
+
618
+ Since it does classification on the last token, it requires to know the position of the last token. If a
619
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
620
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
621
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
622
+ each row of the batch).
623
+ """,
624
+ MPT_START_DOCSTRING,
625
+ )
626
+ class MptForSequenceClassification(MptPreTrainedModel):
627
+ def __init__(self, config: MptConfig):
628
+ super().__init__(config)
629
+ self.num_labels = config.num_labels
630
+ self.transformer = MptModel(config)
631
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
632
+
633
+ # Initialize weights and apply final processing
634
+ self.post_init()
635
+
636
+ @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
637
+ @add_code_sample_docstrings(
638
+ checkpoint=_CHECKPOINT_FOR_DOC,
639
+ output_type=SequenceClassifierOutputWithPast,
640
+ config_class=_CONFIG_FOR_DOC,
641
+ )
642
+ def forward(
643
+ self,
644
+ input_ids: Optional[torch.LongTensor] = None,
645
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
646
+ attention_mask: Optional[torch.Tensor] = None,
647
+ inputs_embeds: Optional[torch.Tensor] = None,
648
+ labels: Optional[torch.Tensor] = None,
649
+ use_cache: Optional[bool] = None,
650
+ output_attentions: Optional[bool] = None,
651
+ output_hidden_states: Optional[bool] = None,
652
+ return_dict: Optional[bool] = None,
653
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
654
+ r"""
655
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
656
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
657
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
658
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
659
+ """
660
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
661
+
662
+ transformer_outputs = self.transformer(
663
+ input_ids,
664
+ past_key_values=past_key_values,
665
+ attention_mask=attention_mask,
666
+ inputs_embeds=inputs_embeds,
667
+ use_cache=use_cache,
668
+ output_attentions=output_attentions,
669
+ output_hidden_states=output_hidden_states,
670
+ return_dict=return_dict,
671
+ )
672
+
673
+ hidden_states = transformer_outputs[0]
674
+ logits = self.score(hidden_states)
675
+
676
+ if input_ids is not None:
677
+ batch_size = input_ids.shape[0]
678
+ else:
679
+ batch_size = inputs_embeds.shape[0]
680
+
681
+ if self.config.pad_token_id is None and batch_size != 1:
682
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
683
+ if self.config.pad_token_id is None:
684
+ sequence_lengths = -1
685
+ else:
686
+ if input_ids is not None:
687
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
688
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
689
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
690
+ sequence_lengths = sequence_lengths.to(logits.device)
691
+ else:
692
+ sequence_lengths = -1
693
+ logger.warning_once(
694
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
695
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
696
+ )
697
+
698
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
699
+
700
+ loss = None
701
+ if labels is not None:
702
+ if self.config.problem_type is None:
703
+ if self.num_labels == 1:
704
+ self.config.problem_type = "regression"
705
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
706
+ self.config.problem_type = "single_label_classification"
707
+ else:
708
+ self.config.problem_type = "multi_label_classification"
709
+
710
+ if self.config.problem_type == "regression":
711
+ loss_fct = MSELoss()
712
+ if self.num_labels == 1:
713
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
714
+ else:
715
+ loss = loss_fct(pooled_logits, labels)
716
+ elif self.config.problem_type == "single_label_classification":
717
+ loss_fct = CrossEntropyLoss()
718
+ loss = loss_fct(pooled_logits, labels)
719
+ elif self.config.problem_type == "multi_label_classification":
720
+ loss_fct = BCEWithLogitsLoss()
721
+ loss = loss_fct(pooled_logits, labels)
722
+ if not return_dict:
723
+ output = (pooled_logits,) + transformer_outputs[1:]
724
+ return ((loss,) + output) if loss is not None else output
725
+
726
+ return SequenceClassifierOutputWithPast(
727
+ loss=loss,
728
+ logits=pooled_logits,
729
+ past_key_values=transformer_outputs.past_key_values,
730
+ hidden_states=transformer_outputs.hidden_states,
731
+ attentions=transformer_outputs.attentions,
732
+ )
733
+
734
+
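The last-non-padding-token lookup used in the forward above (argmax over the padding mask, minus one, with a modulo for the no-padding case) can be checked in isolation; a small sketch assuming `pad_token_id = 0`:

```python
import torch

pad_token_id = 0  # assumed value for this sketch
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 0, 0, 0]])
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]  # handles rows with no padding
print(sequence_lengths)  # tensor([2, 1]) -> index of the last real token per row
```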
735
+ @add_start_docstrings(
736
+ """
737
+ MPT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
738
+ Named-Entity-Recognition (NER) tasks.
739
+ """,
740
+ MPT_START_DOCSTRING,
741
+ )
742
+ class MptForTokenClassification(MptPreTrainedModel):
743
+ def __init__(self, config: MptConfig):
744
+ super().__init__(config)
745
+ self.num_labels = config.num_labels
746
+
747
+ self.transformer = MptModel(config)
748
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
749
+ classifier_dropout = config.classifier_dropout
750
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
751
+ classifier_dropout = config.hidden_dropout
752
+ else:
753
+ classifier_dropout = 0.1
754
+ self.dropout = nn.Dropout(classifier_dropout)
755
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
756
+
757
+ # Initialize weights and apply final processing
758
+ self.post_init()
759
+
760
+ @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING)
761
+ @add_code_sample_docstrings(
762
+ checkpoint=_CHECKPOINT_FOR_DOC,
763
+ output_type=TokenClassifierOutput,
764
+ config_class=_CONFIG_FOR_DOC,
765
+ )
766
+ def forward(
767
+ self,
768
+ input_ids: Optional[torch.LongTensor] = None,
769
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
770
+ attention_mask: Optional[torch.Tensor] = None,
771
+ inputs_embeds: Optional[torch.Tensor] = None,
772
+ labels: Optional[torch.Tensor] = None,
773
+ use_cache: Optional[bool] = None,
774
+ output_attentions: Optional[bool] = None,
775
+ output_hidden_states: Optional[bool] = None,
776
+ return_dict: Optional[bool] = None,
777
+ **deprecated_arguments,
778
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
779
+ r"""
780
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
781
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
782
+ config.num_labels - 1]`. A classification loss (Cross-Entropy) is computed; tokens whose label is set to
783
+ `-100` are ignored.
784
+ """
785
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
786
+
787
+ transformer_outputs = self.transformer(
788
+ input_ids,
789
+ past_key_values=past_key_values,
790
+ attention_mask=attention_mask,
791
+ inputs_embeds=inputs_embeds,
792
+ use_cache=use_cache,
793
+ output_attentions=output_attentions,
794
+ output_hidden_states=output_hidden_states,
795
+ return_dict=return_dict,
796
+ )
797
+
798
+ hidden_states = transformer_outputs[0]
799
+ hidden_states = self.dropout(hidden_states)
800
+ logits = self.classifier(hidden_states)
801
+
802
+ loss = None
803
+ if labels is not None:
804
+ # move labels to correct device to enable model parallelism
805
+ labels = labels.to(logits.device)
806
+ batch_size, seq_length = labels.shape
807
+ loss_fct = CrossEntropyLoss()
808
+ loss = loss_fct(
809
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
810
+ )
811
+
812
+ if not return_dict:
813
+ output = (logits,) + transformer_outputs[2:]
814
+ return ((loss,) + output) if loss is not None else output
815
+
816
+ return TokenClassifierOutput(
817
+ loss=loss,
818
+ logits=logits,
819
+ hidden_states=transformer_outputs.hidden_states,
820
+ attentions=transformer_outputs.attentions,
821
+ )
822
+
823
+
824
+ @add_start_docstrings(
825
+ """
826
+ The MPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD
827
+ (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
828
+ """,
829
+ MPT_START_DOCSTRING,
830
+ )
831
+ class MptForQuestionAnswering(MptPreTrainedModel):
832
+ def __init__(self, config):
833
+ super().__init__(config)
834
+ self.transformer = MptModel(config)
835
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
836
+
837
+ # Initialize weights and apply final processing
838
+ self.post_init()
839
+
840
+ @add_start_docstrings_to_model_forward(MPT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
841
+ def forward(
842
+ self,
843
+ input_ids: Optional[torch.LongTensor] = None,
844
+ attention_mask: Optional[torch.FloatTensor] = None,
845
+ inputs_embeds: Optional[torch.FloatTensor] = None,
846
+ start_positions: Optional[torch.LongTensor] = None,
847
+ end_positions: Optional[torch.LongTensor] = None,
848
+ output_attentions: Optional[bool] = None,
849
+ output_hidden_states: Optional[bool] = None,
850
+ return_dict: Optional[bool] = None,
851
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
852
+ r"""
853
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
854
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
855
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
856
+ are not taken into account for computing the loss.
857
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
858
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
859
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
860
+ are not taken into account for computing the loss.
861
+ """
862
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
863
+
864
+ outputs = self.transformer(
865
+ input_ids,
866
+ attention_mask=attention_mask,
867
+ inputs_embeds=inputs_embeds,
868
+ output_attentions=output_attentions,
869
+ output_hidden_states=output_hidden_states,
870
+ return_dict=return_dict,
871
+ )
872
+
873
+ sequence_output = outputs[0]
874
+
875
+ logits = self.qa_outputs(sequence_output)
876
+ start_logits, end_logits = logits.split(1, dim=-1)
877
+ start_logits = start_logits.squeeze(-1).contiguous()
878
+ end_logits = end_logits.squeeze(-1).contiguous()
879
+
880
+ total_loss = None
881
+ if start_positions is not None and end_positions is not None:
882
+ # If we are on multi-GPU, split add a dimension
883
+ if len(start_positions.size()) > 1:
884
+ start_positions = start_positions.squeeze(-1)
885
+ if len(end_positions.size()) > 1:
886
+ end_positions = end_positions.squeeze(-1)
887
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
888
+ ignored_index = start_logits.size(1)
889
+ start_positions = start_positions.clamp(0, ignored_index)
890
+ end_positions = end_positions.clamp(0, ignored_index)
891
+
892
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
893
+ start_loss = loss_fct(start_logits, start_positions)
894
+ end_loss = loss_fct(end_logits, end_positions)
895
+ total_loss = (start_loss + end_loss) / 2
896
+
897
+ if not return_dict:
898
+ output = (start_logits, end_logits) + outputs[2:]
899
+ return ((total_loss,) + output) if total_loss is not None else output
900
+
901
+ return QuestionAnsweringModelOutput(
902
+ loss=total_loss,
903
+ start_logits=start_logits,
904
+ end_logits=end_logits,
905
+ hidden_states=outputs.hidden_states,
906
+ attentions=outputs.attentions,
907
+ )
908
+
909
+
910
+ __all__ = [
911
+ "MptForCausalLM",
912
+ "MptModel",
913
+ "MptPreTrainedModel",
914
+ "MptForSequenceClassification",
915
+ "MptForTokenClassification",
916
+ "MptForQuestionAnswering",
917
+ ]
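A typical end-to-end use of the classes exported above looks roughly like the sketch below; the checkpoint name is only an example of an MPT checkpoint on the Hub:

```python
from transformers import AutoTokenizer, MptForCausalLM

checkpoint = "mosaicml/mpt-7b"  # example checkpoint; any MPT checkpoint should work similarly
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = MptForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer("MPT uses ALiBi position biases, so", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```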
.venv/lib/python3.11/site-packages/transformers/models/olmo/__init__.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import (
17
+ OptionalDependencyNotAvailable,
18
+ _LazyModule,
19
+ is_sentencepiece_available,
20
+ is_tokenizers_available,
21
+ is_torch_available,
22
+ )
23
+
24
+
25
+ _import_structure = {
26
+ "configuration_olmo": ["OlmoConfig"],
27
+ }
28
+
29
+ try:
30
+ if not is_torch_available():
31
+ raise OptionalDependencyNotAvailable()
32
+ except OptionalDependencyNotAvailable:
33
+ pass
34
+ else:
35
+ _import_structure["modeling_olmo"] = [
36
+ "OlmoForCausalLM",
37
+ "OlmoModel",
38
+ "OlmoPreTrainedModel",
39
+ ]
40
+
41
+ if TYPE_CHECKING:
42
+ from .configuration_olmo import OlmoConfig
43
+
44
+ try:
45
+ if not is_torch_available():
46
+ raise OptionalDependencyNotAvailable()
47
+ except OptionalDependencyNotAvailable:
48
+ pass
49
+ else:
50
+ from .modeling_olmo import (
51
+ OlmoForCausalLM,
52
+ OlmoModel,
53
+ OlmoPreTrainedModel,
54
+ )
55
+
56
+ else:
57
+ import sys
58
+
59
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
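The `_LazyModule` indirection above defers the heavy submodule imports until an attribute is first accessed; a small sketch of what that means from user code (illustrative, the exact internals belong to `_LazyModule`):

```python
import importlib

olmo = importlib.import_module("transformers.models.olmo")  # cheap: model code not loaded yet
config_cls = getattr(olmo, "OlmoConfig")  # first attribute access triggers the real import
print(config_cls.model_type)  # "olmo"
```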
.venv/lib/python3.11/site-packages/transformers/models/olmo/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.39 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/olmo/__pycache__/configuration_olmo.cpython-311.pyc ADDED
Binary file (8.39 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/olmo/__pycache__/modeling_olmo.cpython-311.pyc ADDED
Binary file (44.1 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/olmo/__pycache__/modular_olmo.cpython-311.pyc ADDED
Binary file (8.94 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/olmo/configuration_olmo.py ADDED
@@ -0,0 +1,181 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """OLMo model configuration"""
21
+
22
+ from ...configuration_utils import PretrainedConfig
23
+ from ...utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ class OlmoConfig(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`OlmoModel`]. It is used to instantiate an OLMo
32
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
33
+ defaults will yield a similar configuration to that of the [allenai/OLMo-7B-hf](https://huggingface.co/allenai/OLMo-7B-hf).
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 50304):
41
+ Vocabulary size of the OLMo model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`OlmoModel`]
43
+ hidden_size (`int`, *optional*, defaults to 4096):
44
+ Dimension of the hidden representations.
45
+ intermediate_size (`int`, *optional*, defaults to 11008):
46
+ Dimension of the MLP representations.
47
+ num_hidden_layers (`int`, *optional*, defaults to 32):
48
+ Number of hidden layers in the Transformer decoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 32):
50
+ Number of attention heads for each attention layer in the Transformer decoder.
51
+ num_key_value_heads (`int`, *optional*):
52
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
55
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
+ by meanpooling all the original heads within that group. For more details checkout [this
57
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
58
+ `num_attention_heads`.
59
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
60
+ The non-linear activation function (function or string) in the decoder.
61
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
62
+ The maximum sequence length that this model might ever be used with.
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
67
+ relevant if `config.is_decoder=True`.
68
+ pad_token_id (`int`, *optional*, defaults to 1):
69
+ Padding token id.
70
+ bos_token_id (`int`, *optional*):
71
+ Beginning of stream token id.
72
+ eos_token_id (`int`, *optional*, defaults to 50279):
73
+ End of stream token id.
74
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
75
+ Whether to tie weight embeddings
76
+ rope_theta (`float`, *optional*, defaults to 10000.0):
77
+ The base period of the RoPE embeddings.
78
+ rope_scaling (`Dict`, *optional*):
79
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
80
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
81
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
82
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
83
+ these scaling strategies behave:
84
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
85
+ experimental feature, subject to breaking API changes in future versions.
86
+ attention_bias (`bool`, *optional*, defaults to `False`):
87
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
88
+ attention_dropout (`float`, *optional*, defaults to 0.0):
89
+ The dropout ratio for the attention probabilities.
90
+ clip_qkv (`float`, *optional*):
91
+ If not `None`, elements of query, key and value attention states are clipped so that their
92
+ absolute value does not exceed this value.
93
+
94
+ ```python
95
+ >>> from transformers import OlmoModel, OlmoConfig
96
+
97
+ >>> # Initializing a OLMo 7B style configuration
98
+ >>> configuration = OlmoConfig()
99
+
100
+ >>> # Initializing a model from the OLMo 7B style configuration
101
+ >>> model = OlmoModel(configuration)
102
+
103
+ >>> # Accessing the model configuration
104
+ >>> configuration = model.config
105
+ ```"""
106
+
107
+ model_type = "olmo"
108
+ keys_to_ignore_at_inference = ["past_key_values"]
109
+
110
+ def __init__(
111
+ self,
112
+ vocab_size=50304,
113
+ hidden_size=4096,
114
+ intermediate_size=11008,
115
+ num_hidden_layers=32,
116
+ num_attention_heads=32,
117
+ num_key_value_heads=None,
118
+ hidden_act="silu",
119
+ max_position_embeddings=2048,
120
+ initializer_range=0.02,
121
+ use_cache=True,
122
+ pad_token_id=1,
123
+ bos_token_id=None,
124
+ eos_token_id=50279,
125
+ tie_word_embeddings=False,
126
+ rope_theta=10000.0,
127
+ rope_scaling=None,
128
+ attention_bias=False,
129
+ attention_dropout=0.0,
130
+ clip_qkv=None,
131
+ **kwargs,
132
+ ):
133
+ self.vocab_size = vocab_size
134
+ self.max_position_embeddings = max_position_embeddings
135
+ self.hidden_size = hidden_size
136
+ self.intermediate_size = intermediate_size
137
+ self.num_hidden_layers = num_hidden_layers
138
+ self.num_attention_heads = num_attention_heads
139
+
140
+ # for backward compatibility
141
+ if num_key_value_heads is None:
142
+ num_key_value_heads = num_attention_heads
143
+
144
+ self.num_key_value_heads = num_key_value_heads
145
+ self.hidden_act = hidden_act
146
+ self.initializer_range = initializer_range
147
+ self.use_cache = use_cache
148
+ self.rope_theta = rope_theta
149
+ self.rope_scaling = rope_scaling
150
+ self._rope_scaling_validation()
151
+ self.attention_bias = attention_bias
152
+ self.attention_dropout = attention_dropout
153
+ self.clip_qkv = clip_qkv
154
+
155
+ super().__init__(
156
+ pad_token_id=pad_token_id,
157
+ bos_token_id=bos_token_id,
158
+ eos_token_id=eos_token_id,
159
+ tie_word_embeddings=tie_word_embeddings,
160
+ **kwargs,
161
+ )
162
+
163
+ def _rope_scaling_validation(self):
164
+ """
165
+ Validate the `rope_scaling` configuration.
166
+ """
167
+ if self.rope_scaling is None:
168
+ return
169
+
170
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
171
+ raise ValueError(
172
+ "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
173
+ )
174
+ rope_scaling_type = self.rope_scaling.get("type", None)
175
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
176
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
177
+ raise ValueError(
178
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
179
+ )
180
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
181
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
.venv/lib/python3.11/site-packages/transformers/models/olmo/modeling_olmo.py ADDED
@@ -0,0 +1,842 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/olmo/modular_olmo.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_olmo.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ from typing import Callable, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from ...activations import ACT2FN
14
+ from ...cache_utils import Cache, DynamicCache, StaticCache
15
+ from ...generation import GenerationMixin
16
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
17
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
18
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
20
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
+ from ...processing_utils import Unpack
22
+ from ...utils import (
23
+ LossKwargs,
24
+ add_start_docstrings,
25
+ add_start_docstrings_to_model_forward,
26
+ logging,
27
+ replace_return_docstrings,
28
+ )
29
+ from .configuration_olmo import OlmoConfig
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+ _CONFIG_FOR_DOC = "OlmoConfig"
34
+
35
+
36
+ class OlmoLayerNorm(nn.Module):
37
+ """LayerNorm but with no learnable weight or bias."""
38
+
39
+ def __init__(self, hidden_size: int) -> None:
40
+ super().__init__()
41
+ self.normalized_shape = (hidden_size,)
42
+
43
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
44
+ orig_dtype = hidden_states.dtype
45
+ return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
46
+ orig_dtype
47
+ )
48
+
49
+
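`OlmoLayerNorm` is a parameter-free layer norm computed in float32 and cast back to the input dtype; the same computation stated directly (illustrative):

```python
import torch
import torch.nn.functional as F

hidden_size = 8
x = torch.randn(2, 3, hidden_size, dtype=torch.float16)
out = F.layer_norm(x.to(torch.float32), (hidden_size,), None, None, eps=1e-5).to(x.dtype)
print(out.dtype, out.shape)  # torch.float16 torch.Size([2, 3, 8])
```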
50
+ class OlmoMLP(nn.Module):
51
+ def __init__(self, config):
52
+ super().__init__()
53
+ self.config = config
54
+ self.hidden_size = config.hidden_size
55
+ self.intermediate_size = config.intermediate_size
56
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
57
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
58
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
59
+ self.act_fn = ACT2FN[config.hidden_act]
60
+
61
+ def forward(self, x):
62
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
63
+ return down_proj
64
+
65
+
66
+ def rotate_half(x):
67
+ """Rotates half the hidden dims of the input."""
68
+ x1 = x[..., : x.shape[-1] // 2]
69
+ x2 = x[..., x.shape[-1] // 2 :]
70
+ return torch.cat((-x2, x1), dim=-1)
71
+
72
+
73
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
74
+ """Applies Rotary Position Embedding to the query and key tensors.
75
+
76
+ Args:
77
+ q (`torch.Tensor`): The query tensor.
78
+ k (`torch.Tensor`): The key tensor.
79
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
80
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
81
+ position_ids (`torch.Tensor`, *optional*):
82
+ Deprecated and unused.
83
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
84
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
85
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
86
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
87
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
88
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
89
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
90
+ Returns:
91
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
92
+ """
93
+ cos = cos.unsqueeze(unsqueeze_dim)
94
+ sin = sin.unsqueeze(unsqueeze_dim)
95
+ q_embed = (q * cos) + (rotate_half(q) * sin)
96
+ k_embed = (k * cos) + (rotate_half(k) * sin)
97
+ return q_embed, k_embed
98
+
99
+
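Shape-wise, `apply_rotary_pos_emb` broadcasts `[batch, seq, head_dim]` cos/sin over the head dimension of `[batch, heads, seq, head_dim]` queries and keys; a quick sketch (random placeholders stand in for real cos/sin, and the helper is imported from this module assuming the installed version matches):

```python
import torch
from transformers.models.olmo.modeling_olmo import apply_rotary_pos_emb

batch, heads, seq, head_dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
cos = torch.randn(batch, seq, head_dim)  # placeholder; normally produced by OlmoRotaryEmbedding
sin = torch.randn(batch, seq, head_dim)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # default unsqueeze_dim=1 broadcasts over heads
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 4, 8]) torch.Size([1, 2, 4, 8])
```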
100
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
101
+ """
102
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
103
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
104
+ """
105
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
106
+ if n_rep == 1:
107
+ return hidden_states
108
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
109
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
110
+
111
+
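`repeat_kv` is documented as equivalent to `torch.repeat_interleave` along the head dimension; that can be checked directly (assuming the installed version provides this helper at the same path):

```python
import torch
from transformers.models.olmo.modeling_olmo import repeat_kv

kv = torch.randn(2, 4, 5, 8)       # [batch, num_key_value_heads, seq_len, head_dim]
expanded = repeat_kv(kv, n_rep=3)  # -> [batch, num_key_value_heads * 3, seq_len, head_dim]
print(expanded.shape)              # torch.Size([2, 12, 5, 8])
print(torch.equal(expanded, torch.repeat_interleave(kv, repeats=3, dim=1)))  # True
```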
112
+ def eager_attention_forward(
113
+ module: nn.Module,
114
+ query: torch.Tensor,
115
+ key: torch.Tensor,
116
+ value: torch.Tensor,
117
+ attention_mask: Optional[torch.Tensor],
118
+ scaling: float,
119
+ dropout: float = 0.0,
120
+ **kwargs,
121
+ ):
122
+ key_states = repeat_kv(key, module.num_key_value_groups)
123
+ value_states = repeat_kv(value, module.num_key_value_groups)
124
+
125
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
126
+ if attention_mask is not None:
127
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
128
+ attn_weights = attn_weights + causal_mask
129
+
130
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
131
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
132
+ attn_output = torch.matmul(attn_weights, value_states)
133
+ attn_output = attn_output.transpose(1, 2).contiguous()
134
+
135
+ return attn_output, attn_weights
136
+
137
+
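`eager_attention_forward` only reads `num_key_value_groups` and `training` from the module it is handed, so it can be exercised with a stand-in namespace; a sketch assuming the installed version matches this file:

```python
import types
import torch
from transformers.models.olmo.modeling_olmo import eager_attention_forward

batch, heads, kv_heads, seq, head_dim = 1, 4, 2, 5, 8
module = types.SimpleNamespace(num_key_value_groups=heads // kv_heads, training=False)

q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, kv_heads, seq, head_dim)
v = torch.randn(batch, kv_heads, seq, head_dim)

attn_out, attn_weights = eager_attention_forward(
    module, q, k, v, attention_mask=None, scaling=head_dim**-0.5
)
print(attn_out.shape, attn_weights.shape)  # torch.Size([1, 5, 4, 8]) torch.Size([1, 4, 5, 5])
```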
138
+ class OlmoAttention(nn.Module):
139
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
140
+
141
+ def __init__(self, config: OlmoConfig, layer_idx: int):
142
+ super().__init__()
143
+ self.config = config
144
+ self.layer_idx = layer_idx
145
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
146
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
147
+ self.scaling = self.head_dim**-0.5
148
+ self.attention_dropout = config.attention_dropout
149
+ self.is_causal = True
150
+
151
+ self.q_proj = nn.Linear(
152
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
153
+ )
154
+ self.k_proj = nn.Linear(
155
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
156
+ )
157
+ self.v_proj = nn.Linear(
158
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
159
+ )
160
+ self.o_proj = nn.Linear(
161
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
162
+ )
163
+
164
+ def forward(
165
+ self,
166
+ hidden_states: torch.Tensor,
167
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
168
+ attention_mask: Optional[torch.Tensor],
169
+ past_key_value: Optional[Cache] = None,
170
+ cache_position: Optional[torch.LongTensor] = None,
171
+ **kwargs,
172
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
173
+ input_shape = hidden_states.shape[:-1]
174
+ hidden_shape = (*input_shape, -1, self.head_dim)
175
+
176
+ query_states = self.q_proj(hidden_states)
177
+ key_states = self.k_proj(hidden_states)
178
+ value_states = self.v_proj(hidden_states)
179
+
180
+ if self.config.clip_qkv is not None:
181
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
182
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
183
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
184
+
185
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
186
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
187
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
188
+
189
+ cos, sin = position_embeddings
190
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
191
+
192
+ if past_key_value is not None:
193
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
194
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
195
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
196
+
197
+ attention_interface: Callable = eager_attention_forward
198
+ if self.config._attn_implementation != "eager":
199
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
200
+ logger.warning_once(
201
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
202
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
203
+ )
204
+ else:
205
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
206
+
207
+ attn_output, attn_weights = attention_interface(
208
+ self,
209
+ query_states,
210
+ key_states,
211
+ value_states,
212
+ attention_mask,
213
+ dropout=0.0 if not self.training else self.attention_dropout,
214
+ scaling=self.scaling,
215
+ **kwargs,
216
+ )
217
+
218
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
219
+ attn_output = self.o_proj(attn_output)
220
+ return attn_output, attn_weights
221
+
222
+
223
+ class OlmoDecoderLayer(nn.Module):
224
+ def __init__(self, config: OlmoConfig, layer_idx: int):
225
+ super().__init__()
226
+ self.hidden_size = config.hidden_size
227
+ self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
228
+
229
+ self.mlp = OlmoMLP(config)
230
+ self.input_layernorm = OlmoLayerNorm(config.hidden_size)
231
+ self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
232
+
233
+ def forward(
234
+ self,
235
+ hidden_states: torch.Tensor,
236
+ attention_mask: Optional[torch.Tensor] = None,
237
+ position_ids: Optional[torch.LongTensor] = None,
238
+ past_key_value: Optional[Cache] = None,
239
+ output_attentions: Optional[bool] = False,
240
+ use_cache: Optional[bool] = False,
241
+ cache_position: Optional[torch.LongTensor] = None,
242
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
243
+ **kwargs: Unpack[FlashAttentionKwargs],
244
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
245
+ residual = hidden_states
246
+
247
+ hidden_states = self.input_layernorm(hidden_states)
248
+
249
+ # Self Attention
250
+ hidden_states, self_attn_weights = self.self_attn(
251
+ hidden_states=hidden_states,
252
+ attention_mask=attention_mask,
253
+ position_ids=position_ids,
254
+ past_key_value=past_key_value,
255
+ output_attentions=output_attentions,
256
+ use_cache=use_cache,
257
+ cache_position=cache_position,
258
+ position_embeddings=position_embeddings,
259
+ **kwargs,
260
+ )
261
+ hidden_states = residual + hidden_states
262
+
263
+ # Fully Connected
264
+ residual = hidden_states
265
+ hidden_states = self.post_attention_layernorm(hidden_states)
266
+ hidden_states = self.mlp(hidden_states)
267
+ hidden_states = residual + hidden_states
268
+
269
+ outputs = (hidden_states,)
270
+ if output_attentions:
271
+ outputs += (self_attn_weights,)
272
+
273
+ return outputs
274
+
275
+
276
+ class OlmoRotaryEmbedding(nn.Module):
277
+ def __init__(self, config: OlmoConfig, device=None):
278
+ super().__init__()
279
+ # BC: "rope_type" was originally "type"
280
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
281
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
282
+ else:
283
+ self.rope_type = "default"
284
+ self.max_seq_len_cached = config.max_position_embeddings
285
+ self.original_max_seq_len = config.max_position_embeddings
286
+
287
+ self.config = config
288
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
289
+
290
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
291
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
292
+ self.original_inv_freq = self.inv_freq
293
+
294
+ def _dynamic_frequency_update(self, position_ids, device):
295
+ """
296
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
297
+ 1 - growing beyond the cached sequence length (allow scaling)
298
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
299
+ """
300
+ seq_len = torch.max(position_ids) + 1
301
+ if seq_len > self.max_seq_len_cached: # growth
302
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
303
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
304
+ self.max_seq_len_cached = seq_len
305
+
306
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
307
+ # This .to() is needed if the model has been moved to a device after being initialized (because
308
+ # the buffer is automatically moved, but not the original copy)
309
+ self.original_inv_freq = self.original_inv_freq.to(device)
310
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
311
+ self.max_seq_len_cached = self.original_max_seq_len
312
+
313
+ @torch.no_grad()
314
+ def forward(self, x, position_ids):
315
+ if "dynamic" in self.rope_type:
316
+ self._dynamic_frequency_update(position_ids, device=x.device)
317
+
318
+ # Core RoPE block
319
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
320
+ position_ids_expanded = position_ids[:, None, :].float()
321
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
322
+ device_type = x.device.type
323
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
324
+ with torch.autocast(device_type=device_type, enabled=False):
325
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
326
+ emb = torch.cat((freqs, freqs), dim=-1)
327
+ cos = emb.cos()
328
+ sin = emb.sin()
329
+
330
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
331
+ cos = cos * self.attention_scaling
332
+ sin = sin * self.attention_scaling
333
+
334
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
335
+
336
+
337
+ OLMO_START_DOCSTRING = r"""
338
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
339
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
340
+ etc.)
341
+
342
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
343
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
344
+ and behavior.
345
+
346
+ Parameters:
347
+ config ([`OlmoConfig`]):
348
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
349
+ load the weights associated with the model, only the configuration. Check out the
350
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
351
+ """
352
+
353
+
354
+ @add_start_docstrings(
355
+ "The bare Olmo Model outputting raw hidden-states without any specific head on top.",
356
+ OLMO_START_DOCSTRING,
357
+ )
358
+ class OlmoPreTrainedModel(PreTrainedModel):
359
+ config_class = OlmoConfig
360
+ base_model_prefix = "model"
361
+ supports_gradient_checkpointing = True
362
+ _no_split_modules = ["OlmoDecoderLayer"]
363
+ _skip_keys_device_placement = ["past_key_values"]
364
+ _supports_flash_attn_2 = True
365
+ _supports_sdpa = True
366
+ _supports_flex_attn = True
367
+ _supports_cache_class = True
368
+ _supports_quantized_cache = True
369
+ _supports_static_cache = True
370
+
371
+ def _init_weights(self, module):
372
+ std = self.config.initializer_range
373
+ if isinstance(module, nn.Linear):
374
+ module.weight.data.normal_(mean=0.0, std=std)
375
+ if module.bias is not None:
376
+ module.bias.data.zero_()
377
+ elif isinstance(module, nn.Embedding):
378
+ module.weight.data.normal_(mean=0.0, std=std)
379
+ if module.padding_idx is not None:
380
+ module.weight.data[module.padding_idx].zero_()
381
+
382
+
383
+ OLMO_INPUTS_DOCSTRING = r"""
384
+ Args:
385
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
386
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
387
+ it.
388
+
389
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
390
+ [`PreTrainedTokenizer.__call__`] for details.
391
+
392
+ [What are input IDs?](../glossary#input-ids)
393
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
394
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
395
+
396
+ - 1 for tokens that are **not masked**,
397
+ - 0 for tokens that are **masked**.
398
+
399
+ [What are attention masks?](../glossary#attention-mask)
400
+
401
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
402
+ [`PreTrainedTokenizer.__call__`] for details.
403
+
404
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
405
+ `past_key_values`).
406
+
407
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
408
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
409
+ information on the default strategy.
410
+
411
+ - 1 indicates the head is **not masked**,
412
+ - 0 indicates the head is **masked**.
413
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
414
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
415
+ config.n_positions - 1]`.
416
+
417
+ [What are position IDs?](../glossary#position-ids)
418
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
419
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
420
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
421
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
422
+
423
+ Two formats are allowed:
424
+ - a [`~cache_utils.Cache`] instance, see our
425
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
426
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
427
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
428
+ cache format.
429
+
430
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
431
+ legacy cache format will be returned.
432
+
433
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
434
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
435
+ of shape `(batch_size, sequence_length)`.
436
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
437
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
438
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
439
+ model's internal embedding lookup matrix.
440
+ use_cache (`bool`, *optional*):
441
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
442
+ `past_key_values`).
443
+ output_attentions (`bool`, *optional*):
444
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
445
+ tensors for more detail.
446
+ output_hidden_states (`bool`, *optional*):
447
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
448
+ more detail.
449
+ return_dict (`bool`, *optional*):
450
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
451
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
452
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
453
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
454
+ the complete sequence length.
455
+ """
456
+
457
+
458
+ @add_start_docstrings(
459
+ "The bare Olmo Model outputting raw hidden-states without any specific head on top.",
460
+ OLMO_START_DOCSTRING,
461
+ )
462
+ class OlmoModel(OlmoPreTrainedModel):
463
+ """
464
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoDecoderLayer`]
465
+
466
+ Args:
467
+ config: OlmoConfig
468
+ """
469
+
470
+ def __init__(self, config: OlmoConfig):
471
+ super().__init__(config)
472
+ self.padding_idx = config.pad_token_id
473
+ self.vocab_size = config.vocab_size
474
+
475
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
476
+ self.layers = nn.ModuleList(
477
+ [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
478
+ )
479
+ self.norm = OlmoLayerNorm(config.hidden_size)
480
+ self.rotary_emb = OlmoRotaryEmbedding(config=config)
481
+ self.gradient_checkpointing = False
482
+
483
+ # Initialize weights and apply final processing
484
+ self.post_init()
485
+
486
+ def get_input_embeddings(self):
487
+ return self.embed_tokens
488
+
489
+ def set_input_embeddings(self, value):
490
+ self.embed_tokens = value
491
+
492
+ @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
493
+ def forward(
494
+ self,
495
+ input_ids: torch.LongTensor = None,
496
+ attention_mask: Optional[torch.Tensor] = None,
497
+ position_ids: Optional[torch.LongTensor] = None,
498
+ past_key_values: Optional[Cache] = None,
499
+ inputs_embeds: Optional[torch.FloatTensor] = None,
500
+ use_cache: Optional[bool] = None,
501
+ output_attentions: Optional[bool] = None,
502
+ output_hidden_states: Optional[bool] = None,
503
+ return_dict: Optional[bool] = None,
504
+ cache_position: Optional[torch.LongTensor] = None,
505
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
506
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
507
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
508
+ output_hidden_states = (
509
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
510
+ )
511
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
512
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
513
+
514
+ if (input_ids is None) ^ (inputs_embeds is not None):
515
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
516
+
517
+ if self.gradient_checkpointing and self.training and use_cache:
518
+ logger.warning_once(
519
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
520
+ )
521
+ use_cache = False
522
+
523
+ if inputs_embeds is None:
524
+ inputs_embeds = self.embed_tokens(input_ids)
525
+
526
+ if use_cache and past_key_values is None:
527
+ past_key_values = DynamicCache()
528
+
529
+ if cache_position is None:
530
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
531
+ cache_position = torch.arange(
532
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
533
+ )
534
+
535
+ if position_ids is None:
536
+ position_ids = cache_position.unsqueeze(0)
537
+
538
+ causal_mask = self._update_causal_mask(
539
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
540
+ )
541
+
542
+ hidden_states = inputs_embeds
543
+
544
+ # create position embeddings to be shared across the decoder layers
545
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
546
+
547
+ # decoder layers
548
+ all_hidden_states = () if output_hidden_states else None
549
+ all_self_attns = () if output_attentions else None
550
+
551
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
552
+ if output_hidden_states:
553
+ all_hidden_states += (hidden_states,)
554
+
555
+ if self.gradient_checkpointing and self.training:
556
+ layer_outputs = self._gradient_checkpointing_func(
557
+ decoder_layer.__call__,
558
+ hidden_states,
559
+ causal_mask,
560
+ position_ids,
561
+ past_key_values,
562
+ output_attentions,
563
+ use_cache,
564
+ cache_position,
565
+ position_embeddings,
566
+ )
567
+ else:
568
+ layer_outputs = decoder_layer(
569
+ hidden_states,
570
+ attention_mask=causal_mask,
571
+ position_ids=position_ids,
572
+ past_key_value=past_key_values,
573
+ output_attentions=output_attentions,
574
+ use_cache=use_cache,
575
+ cache_position=cache_position,
576
+ position_embeddings=position_embeddings,
577
+ **flash_attn_kwargs,
578
+ )
579
+
580
+ hidden_states = layer_outputs[0]
581
+
582
+ if output_attentions:
583
+ all_self_attns += (layer_outputs[1],)
584
+
585
+ hidden_states = self.norm(hidden_states)
586
+
587
+ # add hidden states from the last decoder layer
588
+ if output_hidden_states:
589
+ all_hidden_states += (hidden_states,)
590
+
591
+ output = BaseModelOutputWithPast(
592
+ last_hidden_state=hidden_states,
593
+ past_key_values=past_key_values if use_cache else None,
594
+ hidden_states=all_hidden_states,
595
+ attentions=all_self_attns,
596
+ )
597
+ return output if return_dict else output.to_tuple()
598
+
599
+ def _update_causal_mask(
600
+ self,
601
+ attention_mask: torch.Tensor,
602
+ input_tensor: torch.Tensor,
603
+ cache_position: torch.Tensor,
604
+ past_key_values: Cache,
605
+ output_attentions: bool,
606
+ ):
607
+ if self.config._attn_implementation == "flash_attention_2":
608
+ if attention_mask is not None and (attention_mask == 0.0).any():
609
+ return attention_mask
610
+ return None
611
+
612
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
613
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
614
+ # to infer the attention mask.
615
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
616
+ using_static_cache = isinstance(past_key_values, StaticCache)
617
+
618
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
619
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
620
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
621
+ attention_mask,
622
+ inputs_embeds=input_tensor,
623
+ past_key_values_length=past_seen_tokens,
624
+ is_training=self.training,
625
+ ):
626
+ return None
627
+
628
+ dtype, device = input_tensor.dtype, input_tensor.device
629
+ sequence_length = input_tensor.shape[1]
630
+ if using_static_cache:
631
+ target_length = past_key_values.get_max_cache_shape()
632
+ else:
633
+ target_length = (
634
+ attention_mask.shape[-1]
635
+ if isinstance(attention_mask, torch.Tensor)
636
+ else past_seen_tokens + sequence_length + 1
637
+ )
638
+
639
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
640
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
641
+ attention_mask,
642
+ sequence_length=sequence_length,
643
+ target_length=target_length,
644
+ dtype=dtype,
645
+ device=device,
646
+ cache_position=cache_position,
647
+ batch_size=input_tensor.shape[0],
648
+ )
649
+
650
+ if (
651
+ self.config._attn_implementation == "sdpa"
652
+ and attention_mask is not None
653
+ and attention_mask.device.type == "cuda"
654
+ and not output_attentions
655
+ ):
656
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
657
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
658
+ # Details: https://github.com/pytorch/pytorch/issues/110213
659
+ min_dtype = torch.finfo(dtype).min
660
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
661
+
662
+ return causal_mask
663
+
664
+ @staticmethod
665
+ def _prepare_4d_causal_attention_mask_with_cache_position(
666
+ attention_mask: torch.Tensor,
667
+ sequence_length: int,
668
+ target_length: int,
669
+ dtype: torch.dtype,
670
+ device: torch.device,
671
+ cache_position: torch.Tensor,
672
+ batch_size: int,
673
+ **kwargs,
674
+ ):
675
+ """
676
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
677
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
678
+
679
+ Args:
680
+ attention_mask (`torch.Tensor`):
681
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
682
+ `(batch_size, 1, query_length, key_value_length)`.
683
+ sequence_length (`int`):
684
+ The sequence length being processed.
685
+ target_length (`int`):
686
+ The target length: when generating with static cache, the mask should be as long as the static cache,
687
+ to account for the 0 padding, the part of the cache that is not filled yet.
688
+ dtype (`torch.dtype`):
689
+ The dtype to use for the 4D attention mask.
690
+ device (`torch.device`):
691
+ The device to place the 4D attention mask on.
692
+ cache_position (`torch.Tensor`):
693
+ Indices depicting the position of the input sequence tokens in the sequence.
694
+ batch_size (`torch.Tensor`):
695
+ Batch size.
696
+ """
697
+ if attention_mask is not None and attention_mask.dim() == 4:
698
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
699
+ causal_mask = attention_mask
700
+ else:
701
+ min_dtype = torch.finfo(dtype).min
702
+ causal_mask = torch.full(
703
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
704
+ )
705
+ if sequence_length != 1:
706
+ causal_mask = torch.triu(causal_mask, diagonal=1)
707
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
708
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
709
+ if attention_mask is not None:
710
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
711
+ mask_length = attention_mask.shape[-1]
712
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
713
+ padding_mask = padding_mask == 0
714
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
715
+ padding_mask, min_dtype
716
+ )
717
+
718
+ return causal_mask
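For reference, a minimal standalone sketch (toy sizes, not the library helper itself) of how the 2D padding mask is expanded into the 4D additive mask by the static method above:

```python
import torch

# Toy setup: batch of 1, 3 query tokens sitting at absolute positions 2..4 of a
# 5-position key/value window, and a 2D padding mask whose last key position is padding.
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length, batch_size = 3, 5, 1
cache_position = torch.arange(2, 2 + sequence_length)
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])

# Start from an all-blocked (min_dtype) matrix, then keep min_dtype only where the key
# position lies strictly after the query's absolute position; allowed entries become 0.0.
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

# Fold the padding mask in: entries the causal rule allows (0.0) but whose key is padding
# (attention_mask == 0) are re-blocked with min_dtype.
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask == 0, min_dtype
)
print(causal_mask)  # shape (1, 1, 3, 5); 0.0 = attend, min_dtype = blocked
```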
719
+
720
+
721
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
722
+
723
+
724
+ class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
725
+ _tied_weights_keys = ["lm_head.weight"]
726
+ _tp_plan = {"lm_head": "colwise_rep"}
727
+
728
+ def __init__(self, config):
729
+ super().__init__(config)
730
+ self.model = OlmoModel(config)
731
+ self.vocab_size = config.vocab_size
732
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
733
+
734
+ # Initialize weights and apply final processing
735
+ self.post_init()
736
+
737
+ def get_input_embeddings(self):
738
+ return self.model.embed_tokens
739
+
740
+ def set_input_embeddings(self, value):
741
+ self.model.embed_tokens = value
742
+
743
+ def get_output_embeddings(self):
744
+ return self.lm_head
745
+
746
+ def set_output_embeddings(self, new_embeddings):
747
+ self.lm_head = new_embeddings
748
+
749
+ def set_decoder(self, decoder):
750
+ self.model = decoder
751
+
752
+ def get_decoder(self):
753
+ return self.model
754
+
755
+ @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
756
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
757
+ def forward(
758
+ self,
759
+ input_ids: torch.LongTensor = None,
760
+ attention_mask: Optional[torch.Tensor] = None,
761
+ position_ids: Optional[torch.LongTensor] = None,
762
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
763
+ inputs_embeds: Optional[torch.FloatTensor] = None,
764
+ labels: Optional[torch.LongTensor] = None,
765
+ use_cache: Optional[bool] = None,
766
+ output_attentions: Optional[bool] = None,
767
+ output_hidden_states: Optional[bool] = None,
768
+ return_dict: Optional[bool] = None,
769
+ cache_position: Optional[torch.LongTensor] = None,
770
+ num_logits_to_keep: int = 0,
771
+ **kwargs: Unpack[KwargsForCausalLM],
772
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
773
+ r"""
774
+ Args:
775
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
776
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
777
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
778
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
779
+
780
+ num_logits_to_keep (`int`, *optional*):
781
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
782
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
783
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
784
+
785
+ Returns:
786
+
787
+ Example:
788
+
789
+ ```python
790
+ >>> from transformers import AutoTokenizer, OlmoForCausalLM
791
+
792
+ >>> model = OlmoForCausalLM.from_pretrained("meta-olmo/Olmo-2-7b-hf")
793
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo/Olmo-2-7b-hf")
794
+
795
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
796
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
797
+
798
+ >>> # Generate
799
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
800
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
801
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
802
+ ```"""
803
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
804
+ output_hidden_states = (
805
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
806
+ )
807
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
808
+
809
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
810
+ outputs = self.model(
811
+ input_ids=input_ids,
812
+ attention_mask=attention_mask,
813
+ position_ids=position_ids,
814
+ past_key_values=past_key_values,
815
+ inputs_embeds=inputs_embeds,
816
+ use_cache=use_cache,
817
+ output_attentions=output_attentions,
818
+ output_hidden_states=output_hidden_states,
819
+ return_dict=return_dict,
820
+ cache_position=cache_position,
821
+ **kwargs,
822
+ )
823
+
824
+ hidden_states = outputs[0]
825
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
826
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
827
+
828
+ loss = None
829
+ if labels is not None:
830
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
831
+
832
+ if not return_dict:
833
+ output = (logits,) + outputs[1:]
834
+ return (loss,) + output if loss is not None else output
835
+
836
+ return CausalLMOutputWithPast(
837
+ loss=loss,
838
+ logits=logits,
839
+ past_key_values=outputs.past_key_values,
840
+ hidden_states=outputs.hidden_states,
841
+ attentions=outputs.attentions,
842
+ )
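As a shape reference for the rotary block above, here is a small standalone sketch of the cos/sin computation performed in `OlmoRotaryEmbedding.forward`, assuming the standard base-10000 `inv_freq` initialisation (which is what the "default" rope type amounts to); this is a sketch, not the library's init function.

```python
import torch

head_dim, base = 8, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim // 2,)
position_ids = torch.arange(4)[None, :]                                       # (batch=1, seq_len=4)

# Outer product of positions and inverse frequencies, duplicated along the last dim so the
# resulting cos/sin tables cover the full head dimension.
inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1).float()
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # (1, 4, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                              # (1, 4, head_dim)
cos, sin = emb.cos(), emb.sin()
print(cos.shape, sin.shape)  # torch.Size([1, 4, 8]) torch.Size([1, 4, 8])
```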
.venv/lib/python3.11/site-packages/transformers/models/olmo/modular_olmo.py ADDED
@@ -0,0 +1,126 @@
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint
7
+
8
+ from ...cache_utils import Cache
9
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
10
+ from ...utils import logging
11
+ from ..llama.modeling_llama import (
12
+ LlamaAttention,
13
+ LlamaDecoderLayer,
14
+ LlamaForCausalLM,
15
+ LlamaMLP,
16
+ LlamaModel,
17
+ apply_rotary_pos_emb,
18
+ eager_attention_forward,
19
+ )
20
+ from .configuration_olmo import OlmoConfig
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class OlmoLayerNorm(nn.Module):
27
+ """LayerNorm but with no learnable weight or bias."""
28
+
29
+ def __init__(self, hidden_size: int) -> None:
30
+ super().__init__()
31
+ self.normalized_shape = (hidden_size,)
32
+
33
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
34
+ orig_dtype = hidden_states.dtype
35
+ return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
36
+ orig_dtype
37
+ )
38
+
39
+
40
+ class OlmoMLP(LlamaMLP):
41
+ def __init__(self, config):
42
+ super().__init__(config)
43
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
44
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
45
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
46
+
47
+
48
+ class OlmoAttention(LlamaAttention):
49
+ def forward(
50
+ self,
51
+ hidden_states: torch.Tensor,
52
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
53
+ attention_mask: Optional[torch.Tensor],
54
+ past_key_value: Optional[Cache] = None,
55
+ cache_position: Optional[torch.LongTensor] = None,
56
+ **kwargs,
57
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
58
+ input_shape = hidden_states.shape[:-1]
59
+ hidden_shape = (*input_shape, -1, self.head_dim)
60
+
61
+ query_states = self.q_proj(hidden_states)
62
+ key_states = self.k_proj(hidden_states)
63
+ value_states = self.v_proj(hidden_states)
64
+
65
+ if self.config.clip_qkv is not None:
66
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
67
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
68
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
69
+
70
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
71
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
72
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
73
+
74
+ cos, sin = position_embeddings
75
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
76
+
77
+ if past_key_value is not None:
78
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
79
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
80
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
81
+
82
+ attention_interface: Callable = eager_attention_forward
83
+ if self.config._attn_implementation != "eager":
84
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
85
+ logger.warning_once(
86
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
87
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
88
+ )
89
+ else:
90
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
91
+
92
+ attn_output, attn_weights = attention_interface(
93
+ self,
94
+ query_states,
95
+ key_states,
96
+ value_states,
97
+ attention_mask,
98
+ dropout=0.0 if not self.training else self.attention_dropout,
99
+ scaling=self.scaling,
100
+ **kwargs,
101
+ )
102
+
103
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
104
+ attn_output = self.o_proj(attn_output)
105
+ return attn_output, attn_weights
106
+
107
+
108
+ class OlmoDecoderLayer(LlamaDecoderLayer):
109
+ def __init__(self, config: OlmoConfig, layer_idx: int):
110
+ super().__init__(config, layer_idx)
111
+ self.input_layernorm = OlmoLayerNorm(config.hidden_size)
112
+ self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
113
+ self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx)
114
+
115
+
116
+ class OlmoModel(LlamaModel):
117
+ def __init__(self, config: OlmoConfig):
118
+ super().__init__(config)
119
+ self.layers = nn.ModuleList(
120
+ [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
121
+ )
122
+ self.norm = OlmoLayerNorm(config.hidden_size)
123
+
124
+
125
+ class OlmoForCausalLM(LlamaForCausalLM):
126
+ pass
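The two OLMo-specific deviations from the Llama baseline in this modular file are the parameter-free layer norm and the optional `clip_qkv` clamp. A quick numerical sketch (illustrative sizes and clip value, not checkpoint settings):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 16, dtype=torch.bfloat16)

# OlmoLayerNorm: normalize in float32 with no learnable weight/bias, then cast back.
normed = F.layer_norm(x.to(torch.float32), (16,), None, None, eps=1e-5).to(x.dtype)

# clip_qkv: the query/key/value projections are clamped in-place before reshaping.
# clip_qkv=8.0 here is purely illustrative.
clip_qkv = 8.0
q = torch.randn(2, 5, 16)
q.clamp_(min=-clip_qkv, max=clip_qkv)
print(normed.shape, bool(q.abs().max() <= clip_qkv))
```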
.venv/lib/python3.11/site-packages/transformers/models/rt_detr/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import TYPE_CHECKING
17
+
18
+ from ...utils import _LazyModule
19
+ from ...utils.import_utils import define_import_structure
20
+
21
+
22
+ if TYPE_CHECKING:
23
+ from .configuration_rt_detr import *
24
+ from .configuration_rt_detr_resnet import *
25
+ from .image_processing_rt_detr import *
26
+ from .image_processing_rt_detr_fast import *
27
+ from .modeling_rt_detr import *
28
+ from .modeling_rt_detr_resnet import *
29
+ else:
30
+ import sys
31
+
32
+ _file = globals()["__file__"]
33
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
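The `_LazyModule` indirection above defers the submodule imports until an attribute is first accessed. A simplified illustration of the same idea using a module-level `__getattr__` (PEP 562); this is not the `_LazyModule` implementation, and the name-to-submodule map below is a hypothetical subset:

```python
# simplified stand-in for what the real rt_detr __init__.py achieves
import importlib
from typing import Any

# Illustrative subset of the name -> submodule map that the import structure provides.
_LAZY_ATTRS = {
    "RTDetrConfig": ".configuration_rt_detr",
    "RTDetrResNetConfig": ".configuration_rt_detr_resnet",
}

def __getattr__(name: str) -> Any:
    # Import the owning submodule only on first attribute access.
    if name in _LAZY_ATTRS:
        module = importlib.import_module(_LAZY_ATTRS[name], __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```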
.venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (961 Bytes). View file
 
.venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/configuration_rt_detr.cpython-311.pyc ADDED
Binary file (16.8 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/configuration_rt_detr_resnet.cpython-311.pyc ADDED
Binary file (5.88 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/image_processing_rt_detr.cpython-311.pyc ADDED
Binary file (56.3 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/image_processing_rt_detr_fast.cpython-311.pyc ADDED
Binary file (42.2 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/modeling_rt_detr_resnet.cpython-311.pyc ADDED
Binary file (21.3 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/rt_detr/__pycache__/modular_rt_detr.cpython-311.pyc ADDED
Binary file (33.7 kB). View file
 
.venv/lib/python3.11/site-packages/transformers/models/rt_detr/configuration_rt_detr.py ADDED
@@ -0,0 +1,364 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """RT-DETR model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+ from ...utils.backbone_utils import verify_backbone_config_arguments
20
+ from ..auto import CONFIG_MAPPING
21
+ from .configuration_rt_detr_resnet import RTDetrResNetConfig
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class RTDetrConfig(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`RTDetrModel`]. It is used to instantiate a
30
+ RT-DETR model according to the specified arguments, defining the model architecture. Instantiating a configuration
31
+ with the defaults will yield a similar configuration to that of the RT-DETR
32
+ [checkpoing/todo](https://huggingface.co/checkpoing/todo) architecture.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+ Args:
38
+ initializer_range (`float`, *optional*, defaults to 0.01):
39
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
40
+ initializer_bias_prior_prob (`float`, *optional*):
41
+ The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
42
+ If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
43
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
44
+ The epsilon used by the layer normalization layers.
45
+ batch_norm_eps (`float`, *optional*, defaults to 1e-05):
46
+ The epsilon used by the batch normalization layers.
47
+ backbone_config (`Dict`, *optional*, defaults to `RTDetrResNetConfig()`):
48
+ The configuration of the backbone model.
49
+ backbone (`str`, *optional*):
50
+ Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
51
+ will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
52
+ is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
53
+ use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
54
+ Whether to use pretrained weights for the backbone.
55
+ use_timm_backbone (`bool`, *optional*, defaults to `False`):
56
+ Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
57
+ library.
58
+ freeze_backbone_batch_norms (`bool`, *optional*, defaults to `True`):
59
+ Whether to freeze the batch normalization layers in the backbone.
60
+ backbone_kwargs (`dict`, *optional*):
61
+ Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
62
+ e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
63
+ encoder_hidden_dim (`int`, *optional*, defaults to 256):
64
+ Dimension of the layers in hybrid encoder.
65
+ encoder_in_channels (`list`, *optional*, defaults to `[512, 1024, 2048]`):
66
+ Multi-level feature inputs for the encoder.
67
+ feat_strides (`List[int]`, *optional*, defaults to `[8, 16, 32]`):
68
+ Strides used in each feature map.
69
+ encoder_layers (`int`, *optional*, defaults to 1):
70
+ Total of layers to be used by the encoder.
71
+ encoder_ffn_dim (`int`, *optional*, defaults to 1024):
72
+ Dimension of the "intermediate" (often named feed-forward) layer in the encoder.
73
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
74
+ Number of attention heads for each attention layer in the Transformer encoder.
75
+ dropout (`float`, *optional*, defaults to 0.0):
76
+ The ratio for all dropout layers.
77
+ activation_dropout (`float`, *optional*, defaults to 0.0):
78
+ The dropout ratio for activations inside the fully connected layer.
79
+ encode_proj_layers (`List[int]`, *optional*, defaults to `[2]`):
80
+ Indexes of the projected layers to be used in the encoder.
81
+ positional_encoding_temperature (`int`, *optional*, defaults to 10000):
82
+ The temperature parameter used to create the positional encodings.
83
+ encoder_activation_function (`str`, *optional*, defaults to `"gelu"`):
84
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
85
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
86
+ activation_function (`str`, *optional*, defaults to `"silu"`):
87
+ The non-linear activation function (function or string) in the general layer. If string, `"gelu"`,
88
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
89
+ eval_size (`Tuple[int, int]`, *optional*):
90
+ Height and width used to compute the effective height and width of the position embeddings after taking
91
+ into account the stride.
92
+ normalize_before (`bool`, *optional*, defaults to `False`):
93
+ Determine whether to apply layer normalization in the transformer encoder layer before self-attention and
94
+ feed-forward modules.
95
+ hidden_expansion (`float`, *optional*, defaults to 1.0):
96
+ Expansion ratio to enlarge the dimension size of RepVGGBlock and CSPRepLayer.
97
+ d_model (`int`, *optional*, defaults to 256):
98
+ Dimension of the layers excluding the hybrid encoder.
99
+ num_queries (`int`, *optional*, defaults to 300):
100
+ Number of object queries.
101
+ decoder_in_channels (`list`, *optional*, defaults to `[256, 256, 256]`):
102
+ Multi-level feature dimensions for the decoder.
103
+ decoder_ffn_dim (`int`, *optional*, defaults to 1024):
104
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
105
+ num_feature_levels (`int`, *optional*, defaults to 3):
106
+ The number of input feature levels.
107
+ decoder_n_points (`int`, *optional*, defaults to 4):
108
+ The number of sampled keys in each feature level for each attention head in the decoder.
109
+ decoder_layers (`int`, *optional*, defaults to 6):
110
+ Number of decoder layers.
111
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
112
+ Number of attention heads for each attention layer in the Transformer decoder.
113
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
114
+ The non-linear activation function (function or string) in the decoder. If string, `"gelu"`,
115
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
116
+ attention_dropout (`float`, *optional*, defaults to 0.0):
117
+ The dropout ratio for the attention probabilities.
118
+ num_denoising (`int`, *optional*, defaults to 100):
119
+ The total number of denoising tasks or queries to be used for contrastive denoising.
120
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
121
+ The fraction of denoising labels to which random noise should be added.
122
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
123
+ Scale or magnitude of noise to be added to the bounding boxes.
124
+ learn_initial_query (`bool`, *optional*, defaults to `False`):
125
+ Indicates whether the initial query embeddings for the decoder should be learned during training
126
+ anchor_image_size (`Tuple[int, int]`, *optional*):
127
+ Height and width of the input image used during evaluation to generate the bounding box anchors. If None, anchors are generated automatically.
128
+ disable_custom_kernels (`bool`, *optional*, defaults to `True`):
129
+ Whether to disable custom kernels.
130
+ with_box_refine (`bool`, *optional*, defaults to `True`):
131
+ Whether to apply iterative bounding box refinement, where each decoder layer refines the bounding boxes
132
+ based on the predictions from the previous layer.
133
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
134
+ Whether the architecture has an encoder decoder structure.
135
+ matcher_alpha (`float`, *optional*, defaults to 0.25):
136
+ Parameter alpha used by the Hungarian Matcher.
137
+ matcher_gamma (`float`, *optional*, defaults to 2.0):
138
+ Parameter gamma used by the Hungarian Matcher.
139
+ matcher_class_cost (`float`, *optional*, defaults to 2.0):
140
+ The relative weight of the class loss used by the Hungarian Matcher.
141
+ matcher_bbox_cost (`float`, *optional*, defaults to 5.0):
142
+ The relative weight of the bounding box loss used by the Hungarian Matcher.
143
+ matcher_giou_cost (`float`, *optional*, defaults to 2.0):
144
+ The relative weight of the giou loss used by the Hungarian Matcher.
145
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
146
+ Parameter informing whether the focal loss should be used.
147
+ auxiliary_loss (`bool`, *optional*, defaults to `True`):
148
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
149
+ focal_loss_alpha (`float`, *optional*, defaults to 0.75):
150
+ Parameter alpha used to compute the focal loss.
151
+ focal_loss_gamma (`float`, *optional*, defaults to 2.0):
152
+ Parameter gamma used to compute the focal loss.
153
+ weight_loss_vfl (`float`, *optional*, defaults to 1.0):
154
+ Relative weight of the varifocal loss in the object detection loss.
155
+ weight_loss_bbox (`float`, *optional*, defaults to 5.0):
156
+ Relative weight of the L1 bounding box loss in the object detection loss.
157
+ weight_loss_giou (`float`, *optional*, defaults to 2.0):
158
+ Relative weight of the generalized IoU loss in the object detection loss.
159
+ eos_coefficient (`float`, *optional*, defaults to 0.0001):
160
+ Relative classification weight of the 'no-object' class in the object detection loss.
161
+
162
+ Examples:
163
+
164
+ ```python
165
+ >>> from transformers import RTDetrConfig, RTDetrModel
166
+
167
+ >>> # Initializing a RT-DETR configuration
168
+ >>> configuration = RTDetrConfig()
169
+
170
+ >>> # Initializing a model (with random weights) from the configuration
171
+ >>> model = RTDetrModel(configuration)
172
+
173
+ >>> # Accessing the model configuration
174
+ >>> configuration = model.config
175
+ ```"""
176
+
177
+ model_type = "rt_detr"
178
+ layer_types = ["basic", "bottleneck"]
179
+ attribute_map = {
180
+ "hidden_size": "d_model",
181
+ "num_attention_heads": "encoder_attention_heads",
182
+ }
183
+
184
+ def __init__(
185
+ self,
186
+ initializer_range=0.01,
187
+ initializer_bias_prior_prob=None,
188
+ layer_norm_eps=1e-5,
189
+ batch_norm_eps=1e-5,
190
+ # backbone
191
+ backbone_config=None,
192
+ backbone=None,
193
+ use_pretrained_backbone=False,
194
+ use_timm_backbone=False,
195
+ freeze_backbone_batch_norms=True,
196
+ backbone_kwargs=None,
197
+ # encoder HybridEncoder
198
+ encoder_hidden_dim=256,
199
+ encoder_in_channels=[512, 1024, 2048],
200
+ feat_strides=[8, 16, 32],
201
+ encoder_layers=1,
202
+ encoder_ffn_dim=1024,
203
+ encoder_attention_heads=8,
204
+ dropout=0.0,
205
+ activation_dropout=0.0,
206
+ encode_proj_layers=[2],
207
+ positional_encoding_temperature=10000,
208
+ encoder_activation_function="gelu",
209
+ activation_function="silu",
210
+ eval_size=None,
211
+ normalize_before=False,
212
+ hidden_expansion=1.0,
213
+ # decoder RTDetrTransformer
214
+ d_model=256,
215
+ num_queries=300,
216
+ decoder_in_channels=[256, 256, 256],
217
+ decoder_ffn_dim=1024,
218
+ num_feature_levels=3,
219
+ decoder_n_points=4,
220
+ decoder_layers=6,
221
+ decoder_attention_heads=8,
222
+ decoder_activation_function="relu",
223
+ attention_dropout=0.0,
224
+ num_denoising=100,
225
+ label_noise_ratio=0.5,
226
+ box_noise_scale=1.0,
227
+ learn_initial_query=False,
228
+ anchor_image_size=None,
229
+ disable_custom_kernels=True,
230
+ with_box_refine=True,
231
+ is_encoder_decoder=True,
232
+ # Loss
233
+ matcher_alpha=0.25,
234
+ matcher_gamma=2.0,
235
+ matcher_class_cost=2.0,
236
+ matcher_bbox_cost=5.0,
237
+ matcher_giou_cost=2.0,
238
+ use_focal_loss=True,
239
+ auxiliary_loss=True,
240
+ focal_loss_alpha=0.75,
241
+ focal_loss_gamma=2.0,
242
+ weight_loss_vfl=1.0,
243
+ weight_loss_bbox=5.0,
244
+ weight_loss_giou=2.0,
245
+ eos_coefficient=1e-4,
246
+ **kwargs,
247
+ ):
248
+ self.initializer_range = initializer_range
249
+ self.initializer_bias_prior_prob = initializer_bias_prior_prob
250
+ self.layer_norm_eps = layer_norm_eps
251
+ self.batch_norm_eps = batch_norm_eps
252
+ # backbone
253
+ if backbone_config is None and backbone is None:
254
+ logger.info(
255
+ "`backbone_config` and `backbone` are `None`. Initializing the config with the default `RTDetr-ResNet` backbone."
256
+ )
257
+ backbone_config = RTDetrResNetConfig(
258
+ num_channels=3,
259
+ embedding_size=64,
260
+ hidden_sizes=[256, 512, 1024, 2048],
261
+ depths=[3, 4, 6, 3],
262
+ layer_type="bottleneck",
263
+ hidden_act="relu",
264
+ downsample_in_first_stage=False,
265
+ downsample_in_bottleneck=False,
266
+ out_features=None,
267
+ out_indices=[2, 3, 4],
268
+ )
269
+ elif isinstance(backbone_config, dict):
270
+ backbone_model_type = backbone_config.pop("model_type")
271
+ config_class = CONFIG_MAPPING[backbone_model_type]
272
+ backbone_config = config_class.from_dict(backbone_config)
273
+
274
+ verify_backbone_config_arguments(
275
+ use_timm_backbone=use_timm_backbone,
276
+ use_pretrained_backbone=use_pretrained_backbone,
277
+ backbone=backbone,
278
+ backbone_config=backbone_config,
279
+ backbone_kwargs=backbone_kwargs,
280
+ )
281
+
282
+ self.backbone_config = backbone_config
283
+ self.backbone = backbone
284
+ self.use_pretrained_backbone = use_pretrained_backbone
285
+ self.use_timm_backbone = use_timm_backbone
286
+ self.freeze_backbone_batch_norms = freeze_backbone_batch_norms
287
+ self.backbone_kwargs = backbone_kwargs
288
+ # encoder
289
+ self.encoder_hidden_dim = encoder_hidden_dim
290
+ self.encoder_in_channels = encoder_in_channels
291
+ self.feat_strides = feat_strides
292
+ self.encoder_attention_heads = encoder_attention_heads
293
+ self.encoder_ffn_dim = encoder_ffn_dim
294
+ self.dropout = dropout
295
+ self.activation_dropout = activation_dropout
296
+ self.encode_proj_layers = encode_proj_layers
297
+ self.encoder_layers = encoder_layers
298
+ self.positional_encoding_temperature = positional_encoding_temperature
299
+ self.eval_size = eval_size
300
+ self.normalize_before = normalize_before
301
+ self.encoder_activation_function = encoder_activation_function
302
+ self.activation_function = activation_function
303
+ self.hidden_expansion = hidden_expansion
304
+ # decoder
305
+ self.d_model = d_model
306
+ self.num_queries = num_queries
307
+ self.decoder_ffn_dim = decoder_ffn_dim
308
+ self.decoder_in_channels = decoder_in_channels
309
+ self.num_feature_levels = num_feature_levels
310
+ self.decoder_n_points = decoder_n_points
311
+ self.decoder_layers = decoder_layers
312
+ self.decoder_attention_heads = decoder_attention_heads
313
+ self.decoder_activation_function = decoder_activation_function
314
+ self.attention_dropout = attention_dropout
315
+ self.num_denoising = num_denoising
316
+ self.label_noise_ratio = label_noise_ratio
317
+ self.box_noise_scale = box_noise_scale
318
+ self.learn_initial_query = learn_initial_query
319
+ self.anchor_image_size = anchor_image_size
320
+ self.auxiliary_loss = auxiliary_loss
321
+ self.disable_custom_kernels = disable_custom_kernels
322
+ self.with_box_refine = with_box_refine
323
+ # Loss
324
+ self.matcher_alpha = matcher_alpha
325
+ self.matcher_gamma = matcher_gamma
326
+ self.matcher_class_cost = matcher_class_cost
327
+ self.matcher_bbox_cost = matcher_bbox_cost
328
+ self.matcher_giou_cost = matcher_giou_cost
329
+ self.use_focal_loss = use_focal_loss
330
+ self.focal_loss_alpha = focal_loss_alpha
331
+ self.focal_loss_gamma = focal_loss_gamma
332
+ self.weight_loss_vfl = weight_loss_vfl
333
+ self.weight_loss_bbox = weight_loss_bbox
334
+ self.weight_loss_giou = weight_loss_giou
335
+ self.eos_coefficient = eos_coefficient
336
+ super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
337
+
338
+ @property
339
+ def num_attention_heads(self) -> int:
340
+ return self.encoder_attention_heads
341
+
342
+ @property
343
+ def hidden_size(self) -> int:
344
+ return self.d_model
345
+
346
+ @classmethod
347
+ def from_backbone_configs(cls, backbone_config: PretrainedConfig, **kwargs):
348
+ """Instantiate a [`RTDetrConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
349
+ configuration.
350
+
351
+ Args:
352
+ backbone_config ([`PretrainedConfig`]):
353
+ The backbone configuration.
354
+
355
+ Returns:
356
+ [`RTDetrConfig`]: An instance of a configuration object
357
+ """
358
+ return cls(
359
+ backbone_config=backbone_config,
360
+ **kwargs,
361
+ )
362
+
363
+
364
+ __all__ = ["RTDetrConfig"]
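A short usage sketch (random weights, no checkpoint download, assuming a transformers install that ships the RT-DETR family) of the configuration defined above, pairing a custom `RTDetrResNetConfig` backbone with the `from_backbone_configs` helper:

```python
from transformers import RTDetrConfig, RTDetrResNetConfig

# Backbone config mirroring the default ResNet-50-style layout, exposing the last three stages.
backbone_config = RTDetrResNetConfig(depths=[3, 4, 6, 3], out_indices=[2, 3, 4])

# Combine it with the detection config; extra kwargs are forwarded to RTDetrConfig.__init__.
config = RTDetrConfig.from_backbone_configs(backbone_config, num_queries=300, decoder_layers=6)
print(config.backbone_config.layer_type, config.num_queries, config.decoder_layers)
```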
.venv/lib/python3.11/site-packages/transformers/models/rt_detr/configuration_rt_detr_resnet.py ADDED
@@ -0,0 +1,114 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """RT-DETR ResNet model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+ from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class RTDetrResNetConfig(BackboneConfigMixin, PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`RTDetrResnetBackbone`]. It is used to instantiate a
28
+ ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of the ResNet
30
+ [microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+ Args:
36
+ num_channels (`int`, *optional*, defaults to 3):
37
+ The number of input channels.
38
+ embedding_size (`int`, *optional*, defaults to 64):
39
+ Dimensionality (hidden size) for the embedding layer.
40
+ hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`):
41
+ Dimensionality (hidden size) at each stage.
42
+ depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`):
43
+ Depth (number of layers) for each stage.
44
+ layer_type (`str`, *optional*, defaults to `"bottleneck"`):
45
+ The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or
46
+ `"bottleneck"` (used for larger models like resnet-50 and above).
47
+ hidden_act (`str`, *optional*, defaults to `"relu"`):
48
+ The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
49
+ are supported.
50
+ downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
51
+ If `True`, the first stage will downsample the inputs using a `stride` of 2.
52
+ downsample_in_bottleneck (`bool`, *optional*, defaults to `False`):
53
+ If `True`, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a `stride` of 2.
54
+ out_features (`List[str]`, *optional*):
55
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
56
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
57
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
58
+ same order as defined in the `stage_names` attribute.
59
+ out_indices (`List[int]`, *optional*):
60
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
61
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
62
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
63
+ same order as defined in the `stage_names` attribute.
64
+
65
+ Example:
66
+ ```python
67
+ >>> from transformers import RTDetrResNetConfig, RTDetrResnetBackbone
68
+
69
+ >>> # Initializing a ResNet resnet-50 style configuration
70
+ >>> configuration = RTDetrResNetConfig()
71
+
72
+ >>> # Initializing a model (with random weights) from the resnet-50 style configuration
73
+ >>> model = RTDetrResnetBackbone(configuration)
74
+
75
+ >>> # Accessing the model configuration
76
+ >>> configuration = model.config
77
+ ```
78
+ """
79
+
80
+ model_type = "rt_detr_resnet"
81
+ layer_types = ["basic", "bottleneck"]
82
+
83
+ def __init__(
84
+ self,
85
+ num_channels=3,
86
+ embedding_size=64,
87
+ hidden_sizes=[256, 512, 1024, 2048],
88
+ depths=[3, 4, 6, 3],
89
+ layer_type="bottleneck",
90
+ hidden_act="relu",
91
+ downsample_in_first_stage=False,
92
+ downsample_in_bottleneck=False,
93
+ out_features=None,
94
+ out_indices=None,
95
+ **kwargs,
96
+ ):
97
+ super().__init__(**kwargs)
98
+ if layer_type not in self.layer_types:
99
+ raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}")
100
+ self.num_channels = num_channels
101
+ self.embedding_size = embedding_size
102
+ self.hidden_sizes = hidden_sizes
103
+ self.depths = depths
104
+ self.layer_type = layer_type
105
+ self.hidden_act = hidden_act
106
+ self.downsample_in_first_stage = downsample_in_first_stage
107
+ self.downsample_in_bottleneck = downsample_in_bottleneck
108
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
109
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
110
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
111
+ )
112
+
113
+
114
+ __all__ = ["RTDetrResNetConfig"]
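A small sketch of how `out_indices` is aligned with the generated `stage_names` for the default four-stage layout defined above (again assuming a transformers install that ships RT-DETR):

```python
from transformers import RTDetrResNetConfig

config = RTDetrResNetConfig(out_indices=[2, 3, 4])
print(config.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(config.out_features)  # ['stage2', 'stage3', 'stage4'], derived from out_indices
```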
.venv/lib/python3.11/site-packages/transformers/models/rt_detr/image_processing_rt_detr.py ADDED
@@ -0,0 +1,1102 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for RT-DETR."""
16
+
17
+ import pathlib
18
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+
22
+ from ...feature_extraction_utils import BatchFeature
23
+ from ...image_processing_utils import BaseImageProcessor, get_size_dict
24
+ from ...image_transforms import (
25
+ PaddingMode,
26
+ center_to_corners_format,
27
+ corners_to_center_format,
28
+ pad,
29
+ rescale,
30
+ resize,
31
+ to_channel_dimension_format,
32
+ )
33
+ from ...image_utils import (
34
+ IMAGENET_DEFAULT_MEAN,
35
+ IMAGENET_DEFAULT_STD,
36
+ AnnotationFormat,
37
+ AnnotationType,
38
+ ChannelDimension,
39
+ ImageInput,
40
+ PILImageResampling,
41
+ get_image_size,
42
+ infer_channel_dimension_format,
43
+ is_scaled_image,
44
+ make_list_of_images,
45
+ to_numpy_array,
46
+ valid_images,
47
+ validate_annotations,
48
+ validate_preprocess_arguments,
49
+ )
50
+ from ...utils import (
51
+ filter_out_non_signature_kwargs,
52
+ is_flax_available,
53
+ is_jax_tensor,
54
+ is_tf_available,
55
+ is_tf_tensor,
56
+ is_torch_available,
57
+ is_torch_tensor,
58
+ logging,
59
+ requires_backends,
60
+ )
61
+ from ...utils.generic import TensorType
62
+
63
+
64
+ if is_torch_available():
65
+ import torch
66
+
67
+
68
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
69
+
70
+ SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION,)
71
+
72
+
73
+ # Copied from transformers.models.detr.image_processing_detr.get_size_with_aspect_ratio
74
+ def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, int]:
75
+ """
76
+ Computes the output image size given the input image size and the desired output size.
77
+
78
+ Args:
79
+ image_size (`Tuple[int, int]`):
80
+ The input image size.
81
+ size (`int`):
82
+ The desired output size.
83
+ max_size (`int`, *optional*):
84
+ The maximum allowed output size.
85
+ """
86
+ height, width = image_size
87
+ raw_size = None
88
+ if max_size is not None:
89
+ min_original_size = float(min((height, width)))
90
+ max_original_size = float(max((height, width)))
91
+ if max_original_size / min_original_size * size > max_size:
92
+ raw_size = max_size * min_original_size / max_original_size
93
+ size = int(round(raw_size))
94
+
95
+ if (height <= width and height == size) or (width <= height and width == size):
96
+ oh, ow = height, width
97
+ elif width < height:
98
+ ow = size
99
+ if max_size is not None and raw_size is not None:
100
+ oh = int(raw_size * height / width)
101
+ else:
102
+ oh = int(size * height / width)
103
+ else:
104
+ oh = size
105
+ if max_size is not None and raw_size is not None:
106
+ ow = int(raw_size * width / height)
107
+ else:
108
+ ow = int(size * width / height)
109
+
110
+ return (oh, ow)
111
+
112
+
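# Illustrative sketch (assumption: this file is the usual
# transformers.models.rt_detr.image_processing_rt_detr module): a quick numeric check of the helper above.
from transformers.models.rt_detr.image_processing_rt_detr import get_size_with_aspect_ratio

# The shorter edge (height) is matched to 320 and the aspect ratio is kept: (480, 640) -> (320, 426).
print(get_size_with_aspect_ratio((480, 640), size=320))
# 640 / 480 * 800 ≈ 1067 exceeds max_size=1000, so the target shrinks to 1000 * 480 / 640 = 750,
# giving (750, 1000).
print(get_size_with_aspect_ratio((480, 640), size=800, max_size=1000))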
113
+ # Copied from transformers.models.detr.image_processing_detr.get_resize_output_image_size
114
+ def get_resize_output_image_size(
115
+ input_image: np.ndarray,
116
+ size: Union[int, Tuple[int, int], List[int]],
117
+ max_size: Optional[int] = None,
118
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
119
+ ) -> Tuple[int, int]:
120
+ """
121
+ Computes the output image size given the input image size and the desired output size. If the desired output size
122
+ is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output
123
+ image size is computed by keeping the aspect ratio of the input image size.
124
+
125
+ Args:
126
+ input_image (`np.ndarray`):
127
+ The image to resize.
128
+ size (`int` or `Tuple[int, int]` or `List[int]`):
129
+ The desired output size.
130
+ max_size (`int`, *optional*):
131
+ The maximum allowed output size.
132
+ input_data_format (`ChannelDimension` or `str`, *optional*):
133
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
134
+ """
135
+ image_size = get_image_size(input_image, input_data_format)
136
+ if isinstance(size, (list, tuple)):
137
+ return size
138
+
139
+ return get_size_with_aspect_ratio(image_size, size, max_size)
140
+
141
+
142
+ # Copied from transformers.models.detr.image_processing_detr.get_image_size_for_max_height_width
143
+ def get_image_size_for_max_height_width(
144
+ input_image: np.ndarray,
145
+ max_height: int,
146
+ max_width: int,
147
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
148
+ ) -> Tuple[int, int]:
149
+ """
150
+ Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
151
+ Important, even if image_height < max_height and image_width < max_width, the image will be resized
152
+ to at least one of the edges be equal to max_height or max_width.
153
+ For example:
154
+ - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
155
+ - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
156
+ Args:
157
+ input_image (`np.ndarray`):
158
+ The image to resize.
159
+ max_height (`int`):
160
+ The maximum allowed height.
161
+ max_width (`int`):
162
+ The maximum allowed width.
163
+ input_data_format (`ChannelDimension` or `str`, *optional*):
164
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
165
+ """
166
+ image_size = get_image_size(input_image, input_data_format)
167
+ height, width = image_size
168
+ height_scale = max_height / height
169
+ width_scale = max_width / width
170
+ min_scale = min(height_scale, width_scale)
171
+ new_height = int(height * min_scale)
172
+ new_width = int(width * min_scale)
173
+ return new_height, new_width
174
+
175
+
176
+ # Copied from transformers.models.detr.image_processing_detr.get_numpy_to_framework_fn
177
+ def get_numpy_to_framework_fn(arr) -> Callable:
178
+ """
179
+ Returns a function that converts a numpy array to the framework of the input array.
180
+
181
+ Args:
182
+ arr (`np.ndarray`): The array to convert.
183
+ """
184
+ if isinstance(arr, np.ndarray):
185
+ return np.array
186
+ if is_tf_available() and is_tf_tensor(arr):
187
+ import tensorflow as tf
188
+
189
+ return tf.convert_to_tensor
190
+ if is_torch_available() and is_torch_tensor(arr):
191
+ import torch
192
+
193
+ return torch.tensor
194
+ if is_flax_available() and is_jax_tensor(arr):
195
+ import jax.numpy as jnp
196
+
197
+ return jnp.array
198
+ raise ValueError(f"Cannot convert arrays of type {type(arr)}")
199
+
200
+
201
+ # Copied from transformers.models.detr.image_processing_detr.safe_squeeze
202
+ def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
203
+ """
204
+ Squeezes an array, but only if the axis specified has dim 1.
205
+ """
206
+ if axis is None:
207
+ return arr.squeeze()
208
+
209
+ try:
210
+ return arr.squeeze(axis=axis)
211
+ except ValueError:
212
+ return arr
213
+
214
+
215
+ # Copied from transformers.models.detr.image_processing_detr.normalize_annotation
216
+ def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict:
217
+ image_height, image_width = image_size
218
+ norm_annotation = {}
219
+ for key, value in annotation.items():
220
+ if key == "boxes":
221
+ boxes = value
222
+ boxes = corners_to_center_format(boxes)
223
+ boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32)
224
+ norm_annotation[key] = boxes
225
+ else:
226
+ norm_annotation[key] = value
227
+ return norm_annotation
228
+
229
+
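# Illustrative sketch (assumes the module path transformers.models.rt_detr.image_processing_rt_detr):
# `normalize_annotation` turns absolute corner boxes into relative (center_x, center_y, width, height).
import numpy as np
from transformers.models.rt_detr.image_processing_rt_detr import normalize_annotation

annotation = {"boxes": np.array([[100.0, 100.0, 300.0, 200.0]]), "class_labels": np.array([3])}
normalized = normalize_annotation(annotation, image_size=(400, 600))  # (height, width)
# (100, 100, 300, 200) -> center (200, 150), size (200, 100), divided by (600, 400, 600, 400):
print(normalized["boxes"])  # expected: approximately [[0.333, 0.375, 0.333, 0.25]]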
230
+ # Copied from transformers.models.detr.image_processing_detr.max_across_indices
231
+ def max_across_indices(values: Iterable[Any]) -> List[Any]:
232
+ """
233
+ Return the maximum value across all indices of an iterable of values.
234
+ """
235
+ return [max(values_i) for values_i in zip(*values)]
236
+
237
+
238
+ # Copied from transformers.models.detr.image_processing_detr.get_max_height_width
239
+ def get_max_height_width(
240
+ images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
241
+ ) -> List[int]:
242
+ """
243
+ Get the maximum height and width across all images in a batch.
244
+ """
245
+ if input_data_format is None:
246
+ input_data_format = infer_channel_dimension_format(images[0])
247
+
248
+ if input_data_format == ChannelDimension.FIRST:
249
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
250
+ elif input_data_format == ChannelDimension.LAST:
251
+ max_height, max_width, _ = max_across_indices([img.shape for img in images])
252
+ else:
253
+ raise ValueError(f"Invalid channel dimension format: {input_data_format}")
254
+ return (max_height, max_width)
255
+
256
+
257
+ # Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
258
+ def make_pixel_mask(
259
+ image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
260
+ ) -> np.ndarray:
261
+ """
262
+ Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
263
+
264
+ Args:
265
+ image (`np.ndarray`):
266
+ Image to make the pixel mask for.
267
+ output_size (`Tuple[int, int]`):
268
+ Output size of the mask.
269
+ """
270
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
271
+ mask = np.zeros(output_size, dtype=np.int64)
272
+ mask[:input_height, :input_width] = 1
273
+ return mask
274
+
275
+
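# Illustrative sketch (assumes the module path transformers.models.rt_detr.image_processing_rt_detr):
# the pixel mask marks the valid region of an image inside a padded canvas.
import numpy as np
from transformers.image_utils import ChannelDimension
from transformers.models.rt_detr.image_processing_rt_detr import make_pixel_mask

image = np.zeros((3, 2, 4))  # 3 channels, height 2, width 4
mask = make_pixel_mask(image, output_size=(3, 6), input_data_format=ChannelDimension.FIRST)
print(mask)  # expected: ones in the top-left 2x4 block of a 3x6 array, zeros elsewhere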
276
+ def prepare_coco_detection_annotation(
277
+ image,
278
+ target,
279
+ return_segmentation_masks: bool = False,
280
+ input_data_format: Optional[Union[ChannelDimension, str]] = None,
281
+ ):
282
+ """
283
+ Convert the target in COCO format into the format expected by RTDETR.
284
+ """
285
+ image_height, image_width = get_image_size(image, channel_dim=input_data_format)
286
+
287
+ image_id = target["image_id"]
288
+ image_id = np.asarray([image_id], dtype=np.int64)
289
+
290
+ # Get all COCO annotations for the given image.
291
+ annotations = target["annotations"]
292
+ annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0]
293
+
294
+ classes = [obj["category_id"] for obj in annotations]
295
+ classes = np.asarray(classes, dtype=np.int64)
296
+
297
+ # for conversion to coco api
298
+ area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32)
299
+ iscrowd = np.asarray([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in annotations], dtype=np.int64)
300
+
301
+ boxes = [obj["bbox"] for obj in annotations]
302
+ # guard against no boxes via resizing
303
+ boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
304
+ boxes[:, 2:] += boxes[:, :2]
305
+ boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
306
+ boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)
307
+
308
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
309
+
310
+ new_target = {}
311
+ new_target["image_id"] = image_id
312
+ new_target["class_labels"] = classes[keep]
313
+ new_target["boxes"] = boxes[keep]
314
+ new_target["area"] = area[keep]
315
+ new_target["iscrowd"] = iscrowd[keep]
316
+ new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64)
317
+
318
+ if annotations and "keypoints" in annotations[0]:
319
+ keypoints = [obj["keypoints"] for obj in annotations]
320
+ # Converting the filtered keypoints list to a numpy array
321
+ keypoints = np.asarray(keypoints, dtype=np.float32)
322
+ # Apply the keep mask here to filter the relevant annotations
323
+ keypoints = keypoints[keep]
324
+ num_keypoints = keypoints.shape[0]
325
+ keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
326
+ new_target["keypoints"] = keypoints
327
+
328
+ return new_target
329
+
330
+
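# Illustrative sketch (assumes the module path transformers.models.rt_detr.image_processing_rt_detr):
# the COCO-style target this function expects and what it returns. Boxes come in as (x, y, w, h),
# leave as clipped (x0, y0, x1, y1), and degenerate boxes are dropped.
import numpy as np
from transformers.models.rt_detr.image_processing_rt_detr import prepare_coco_detection_annotation

image = np.zeros((3, 480, 640), dtype=np.uint8)  # channels-first dummy image
target = {
    "image_id": 1,
    "annotations": [{"bbox": [10.0, 20.0, 100.0, 50.0], "category_id": 7, "area": 5000.0, "iscrowd": 0}],
}
new_target = prepare_coco_detection_annotation(image, target)
print(new_target["boxes"])        # expected: [[10., 20., 110., 70.]]
print(sorted(new_target.keys()))  # area, boxes, class_labels, image_id, iscrowd, orig_size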
331
+ # Copied from transformers.models.detr.image_processing_detr.resize_annotation
332
+ def resize_annotation(
333
+ annotation: Dict[str, Any],
334
+ orig_size: Tuple[int, int],
335
+ target_size: Tuple[int, int],
336
+ threshold: float = 0.5,
337
+ resample: PILImageResampling = PILImageResampling.NEAREST,
338
+ ):
339
+ """
340
+ Resizes an annotation to a target size.
341
+
342
+ Args:
343
+ annotation (`Dict[str, Any]`):
344
+ The annotation dictionary.
345
+ orig_size (`Tuple[int, int]`):
346
+ The original size of the input image.
347
+ target_size (`Tuple[int, int]`):
348
+ The target size of the image, as returned by the preprocessing `resize` step.
349
+ threshold (`float`, *optional*, defaults to 0.5):
350
+ The threshold used to binarize the segmentation masks.
351
+ resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`):
352
+ The resampling filter to use when resizing the masks.
353
+ """
354
+ ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size))
355
+ ratio_height, ratio_width = ratios
356
+
357
+ new_annotation = {}
358
+ new_annotation["size"] = target_size
359
+
360
+ for key, value in annotation.items():
361
+ if key == "boxes":
362
+ boxes = value
363
+ scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32)
364
+ new_annotation["boxes"] = scaled_boxes
365
+ elif key == "area":
366
+ area = value
367
+ scaled_area = area * (ratio_width * ratio_height)
368
+ new_annotation["area"] = scaled_area
369
+ elif key == "masks":
370
+ masks = value[:, None]
371
+ masks = np.array([resize(mask, target_size, resample=resample) for mask in masks])
372
+ masks = masks.astype(np.float32)
373
+ masks = masks[:, 0] > threshold
374
+ new_annotation["masks"] = masks
375
+ elif key == "size":
376
+ new_annotation["size"] = target_size
377
+ else:
378
+ new_annotation[key] = value
379
+
380
+ return new_annotation
381
+
382
+
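# Illustrative sketch (assumes the module path transformers.models.rt_detr.image_processing_rt_detr):
# resizing from (480, 640) to (640, 640) scales each box coordinate by its axis ratio and the area
# by the product of the two ratios.
import numpy as np
from transformers.models.rt_detr.image_processing_rt_detr import resize_annotation

annotation = {"boxes": np.array([[10.0, 20.0, 110.0, 70.0]]), "area": np.array([5000.0])}
resized = resize_annotation(annotation, orig_size=(480, 640), target_size=(640, 640))
# ratio_height = 640 / 480 ≈ 1.333, ratio_width = 640 / 640 = 1.0
print(resized["boxes"])  # expected: approximately [[10., 26.67, 110., 93.33]]
print(resized["area"])   # expected: approximately [6666.67]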
383
+ class RTDetrImageProcessor(BaseImageProcessor):
384
+ r"""
385
+ Constructs a RT-DETR image processor.
386
+
387
+ Args:
388
+ format (`str`, *optional*, defaults to `AnnotationFormat.COCO_DETECTION`):
389
+ Data format of the annotations. Only "coco_detection" is supported.
390
+ do_resize (`bool`, *optional*, defaults to `True`):
391
+ Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be
392
+ overridden by the `do_resize` parameter in the `preprocess` method.
393
+ size (`Dict[str, int]`, *optional*, defaults to `{"height": 640, "width": 640}`):
394
+ Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
395
+ in the `preprocess` method. Available options are:
396
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
397
+ Do NOT keep the aspect ratio.
398
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
399
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
400
+ less or equal to `longest_edge`.
401
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
402
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
403
+ `max_width`.
404
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
405
+ Resampling filter to use if resizing the image.
406
+ do_rescale (`bool`, *optional*, defaults to `True`):
407
+ Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
408
+ `do_rescale` parameter in the `preprocess` method.
409
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
410
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
411
+ `preprocess` method.
412
+ do_normalize (`bool`, *optional*, defaults to `False`):
413
+ Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
414
+ `preprocess` method.
416
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
417
+ Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
418
+ channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
419
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
420
+ Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
421
+ for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
422
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
423
+ Controls whether to convert the annotations to the format expected by the DETR model. Converts the
424
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
425
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
426
+ do_pad (`bool`, *optional*, defaults to `False`):
427
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
428
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
429
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
430
+ Otherwise, the image will be padded to the maximum height and width of the batch.
431
+ pad_size (`Dict[str, int]`, *optional*):
432
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
433
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
434
+ height and width in the batch.
435
+ """
436
+
437
+ model_input_names = ["pixel_values", "pixel_mask"]
438
+
439
+ def __init__(
440
+ self,
441
+ format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
442
+ do_resize: bool = True,
443
+ size: Dict[str, int] = None,
444
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
445
+ do_rescale: bool = True,
446
+ rescale_factor: Union[int, float] = 1 / 255,
447
+ do_normalize: bool = False,
448
+ image_mean: Union[float, List[float]] = None,
449
+ image_std: Union[float, List[float]] = None,
450
+ do_convert_annotations: bool = True,
451
+ do_pad: bool = False,
452
+ pad_size: Optional[Dict[str, int]] = None,
453
+ **kwargs,
454
+ ) -> None:
455
+ size = size if size is not None else {"height": 640, "width": 640}
456
+ size = get_size_dict(size, default_to_square=False)
457
+
458
+ if do_convert_annotations is None:
459
+ do_convert_annotations = do_normalize
460
+
461
+ super().__init__(**kwargs)
462
+ self.format = format
463
+ self.do_resize = do_resize
464
+ self.size = size
465
+ self.resample = resample
466
+ self.do_rescale = do_rescale
467
+ self.rescale_factor = rescale_factor
468
+ self.do_normalize = do_normalize
469
+ self.do_convert_annotations = do_convert_annotations
470
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
471
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
472
+ self.do_pad = do_pad
473
+ self.pad_size = pad_size
474
+
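# Illustrative sketch (assumes RTDetrImageProcessor is exported by this transformers build): the
# defaults above match the RT-DETR recipe: resize to 640x640, rescale by 1/255, no ImageNet
# normalization and no padding.
from transformers import RTDetrImageProcessor

processor = RTDetrImageProcessor()
print(processor.size)          # {'height': 640, 'width': 640}
print(processor.do_normalize)  # False
print(processor.do_pad)        # False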
475
+ def prepare_annotation(
476
+ self,
477
+ image: np.ndarray,
478
+ target: Dict,
479
+ format: Optional[AnnotationFormat] = None,
480
+ return_segmentation_masks: bool = None,
481
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
482
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
483
+ ) -> Dict:
484
+ """
485
+ Prepare an annotation for feeding into RTDETR model.
486
+ """
487
+ format = format if format is not None else self.format
488
+
489
+ if format == AnnotationFormat.COCO_DETECTION:
490
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
491
+ target = prepare_coco_detection_annotation(
492
+ image, target, return_segmentation_masks, input_data_format=input_data_format
493
+ )
494
+ else:
495
+ raise ValueError(f"Format {format} is not supported.")
496
+ return target
497
+
498
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize
499
+ def resize(
500
+ self,
501
+ image: np.ndarray,
502
+ size: Dict[str, int],
503
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
504
+ data_format: Optional[ChannelDimension] = None,
505
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
506
+ **kwargs,
507
+ ) -> np.ndarray:
508
+ """
509
+ Resize the image to the given size. `size` is a dictionary whose keys select the resizing strategy; see the
510
+ `size` argument below for the accepted key combinations.
511
+
512
+ Args:
513
+ image (`np.ndarray`):
514
+ Image to resize.
515
+ size (`Dict[str, int]`):
516
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
517
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
518
+ Do NOT keep the aspect ratio.
519
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
520
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
521
+ less or equal to `longest_edge`.
522
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
523
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
524
+ `max_width`.
525
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
526
+ Resampling filter to use if resizing the image.
527
+ data_format (`str` or `ChannelDimension`, *optional*):
528
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
529
+ image is used.
530
+ input_data_format (`ChannelDimension` or `str`, *optional*):
531
+ The channel dimension format of the input image. If not provided, it will be inferred.
532
+ """
533
+ if "max_size" in kwargs:
534
+ logger.warning_once(
535
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
536
+ "Please specify in `size['longest_edge'] instead`.",
537
+ )
538
+ max_size = kwargs.pop("max_size")
539
+ else:
540
+ max_size = None
541
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
542
+ if "shortest_edge" in size and "longest_edge" in size:
543
+ new_size = get_resize_output_image_size(
544
+ image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
545
+ )
546
+ elif "max_height" in size and "max_width" in size:
547
+ new_size = get_image_size_for_max_height_width(
548
+ image, size["max_height"], size["max_width"], input_data_format=input_data_format
549
+ )
550
+ elif "height" in size and "width" in size:
551
+ new_size = (size["height"], size["width"])
552
+ else:
553
+ raise ValueError(
554
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
555
+ f" {size.keys()}."
556
+ )
557
+ image = resize(
558
+ image,
559
+ size=new_size,
560
+ resample=resample,
561
+ data_format=data_format,
562
+ input_data_format=input_data_format,
563
+ **kwargs,
564
+ )
565
+ return image
566
+
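# Illustrative sketch (assumes RTDetrImageProcessor is importable from transformers): the three
# accepted `size` layouts for `resize`, shown on a dummy 480x640 image.
import numpy as np
from transformers import RTDetrImageProcessor

processor = RTDetrImageProcessor()
image = np.zeros((480, 640, 3), dtype=np.uint8)

exact = processor.resize(image, size={"height": 640, "width": 640})                    # aspect ratio ignored
shortest = processor.resize(image, size={"shortest_edge": 320, "longest_edge": 640})   # shorter edge -> 320
bounded = processor.resize(image, size={"max_height": 320, "max_width": 320})          # fits inside 320x320
print(exact.shape, shortest.shape, bounded.shape)  # (640, 640, 3) (320, 426, 3) (240, 320, 3)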
567
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.resize_annotation
568
+ def resize_annotation(
569
+ self,
570
+ annotation,
571
+ orig_size,
572
+ size,
573
+ resample: PILImageResampling = PILImageResampling.NEAREST,
574
+ ) -> Dict:
575
+ """
576
+ Resize the annotation (boxes, areas and masks) so that it matches the resized image, given the original
577
+ image size and the target size produced by the `resize` step.
578
+ """
579
+ return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample)
580
+
581
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.rescale
582
+ def rescale(
583
+ self,
584
+ image: np.ndarray,
585
+ rescale_factor: float,
586
+ data_format: Optional[Union[str, ChannelDimension]] = None,
587
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
588
+ ) -> np.ndarray:
589
+ """
590
+ Rescale the image by the given factor. image = image * rescale_factor.
591
+
592
+ Args:
593
+ image (`np.ndarray`):
594
+ Image to rescale.
595
+ rescale_factor (`float`):
596
+ The value to use for rescaling.
597
+ data_format (`str` or `ChannelDimension`, *optional*):
598
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
599
+ image is used. Can be one of:
600
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
601
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
602
+ input_data_format (`str` or `ChannelDimension`, *optional*):
603
+ The channel dimension format for the input image. If unset, is inferred from the input image. Can be
604
+ one of:
605
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
606
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
607
+ """
608
+ return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
609
+
610
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize_annotation
611
+ def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
612
+ """
613
+ Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to
614
+ `[center_x, center_y, width, height]` format and from absolute to relative pixel values.
615
+ """
616
+ return normalize_annotation(annotation, image_size=image_size)
617
+
618
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._update_annotation_for_padded_image
619
+ def _update_annotation_for_padded_image(
620
+ self,
621
+ annotation: Dict,
622
+ input_image_size: Tuple[int, int],
623
+ output_image_size: Tuple[int, int],
624
+ padding,
625
+ update_bboxes,
626
+ ) -> Dict:
627
+ """
628
+ Update the annotation for a padded image.
629
+ """
630
+ new_annotation = {}
631
+ new_annotation["size"] = output_image_size
632
+
633
+ for key, value in annotation.items():
634
+ if key == "masks":
635
+ masks = value
636
+ masks = pad(
637
+ masks,
638
+ padding,
639
+ mode=PaddingMode.CONSTANT,
640
+ constant_values=0,
641
+ input_data_format=ChannelDimension.FIRST,
642
+ )
643
+ masks = safe_squeeze(masks, 1)
644
+ new_annotation["masks"] = masks
645
+ elif key == "boxes" and update_bboxes:
646
+ boxes = value
647
+ boxes *= np.asarray(
648
+ [
649
+ input_image_size[1] / output_image_size[1],
650
+ input_image_size[0] / output_image_size[0],
651
+ input_image_size[1] / output_image_size[1],
652
+ input_image_size[0] / output_image_size[0],
653
+ ]
654
+ )
655
+ new_annotation["boxes"] = boxes
656
+ elif key == "size":
657
+ new_annotation["size"] = output_image_size
658
+ else:
659
+ new_annotation[key] = value
660
+ return new_annotation
661
+
662
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
663
+ def _pad_image(
664
+ self,
665
+ image: np.ndarray,
666
+ output_size: Tuple[int, int],
667
+ annotation: Optional[Dict[str, Any]] = None,
668
+ constant_values: Union[float, Iterable[float]] = 0,
669
+ data_format: Optional[ChannelDimension] = None,
670
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
671
+ update_bboxes: bool = True,
672
+ ) -> np.ndarray:
673
+ """
674
+ Pad an image with zeros to the given size.
675
+ """
676
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
677
+ output_height, output_width = output_size
678
+
679
+ pad_bottom = output_height - input_height
680
+ pad_right = output_width - input_width
681
+ padding = ((0, pad_bottom), (0, pad_right))
682
+ padded_image = pad(
683
+ image,
684
+ padding,
685
+ mode=PaddingMode.CONSTANT,
686
+ constant_values=constant_values,
687
+ data_format=data_format,
688
+ input_data_format=input_data_format,
689
+ )
690
+ if annotation is not None:
691
+ annotation = self._update_annotation_for_padded_image(
692
+ annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
693
+ )
694
+ return padded_image, annotation
695
+
696
+ # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad
697
+ def pad(
698
+ self,
699
+ images: List[np.ndarray],
700
+ annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
701
+ constant_values: Union[float, Iterable[float]] = 0,
702
+ return_pixel_mask: bool = True,
703
+ return_tensors: Optional[Union[str, TensorType]] = None,
704
+ data_format: Optional[ChannelDimension] = None,
705
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
706
+ update_bboxes: bool = True,
707
+ pad_size: Optional[Dict[str, int]] = None,
708
+ ) -> BatchFeature:
709
+ """
710
+ Pads a batch of images at the bottom and right with zeros, up to the largest height and width in the batch
711
+ (or to `pad_size` if given), and optionally returns the corresponding pixel masks.
712
+
713
+ Args:
714
+ images (List[`np.ndarray`]):
715
+ Images to pad.
716
+ annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
717
+ Annotations to transform according to the padding that is applied to the images.
718
+ constant_values (`float` or `Iterable[float]`, *optional*):
719
+ The value to use for the padding if `mode` is `"constant"`.
720
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
721
+ Whether to return a pixel mask.
722
+ return_tensors (`str` or `TensorType`, *optional*):
723
+ The type of tensors to return. Can be one of:
724
+ - Unset: Return a list of `np.ndarray`.
725
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
726
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
727
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
728
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
729
+ data_format (`str` or `ChannelDimension`, *optional*):
730
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
731
+ input_data_format (`ChannelDimension` or `str`, *optional*):
732
+ The channel dimension format of the input image. If not provided, it will be inferred.
733
+ update_bboxes (`bool`, *optional*, defaults to `True`):
734
+ Whether to update the bounding boxes in the annotations to match the padded images. If the
735
+ bounding boxes have not been converted to relative coordinates and `(center_x, center_y, width, height)`
736
+ format, the bounding boxes will not be updated.
737
+ pad_size (`Dict[str, int]`, *optional*):
738
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
739
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
740
+ height and width in the batch.
741
+ """
742
+ pad_size = pad_size if pad_size is not None else self.pad_size
743
+ if pad_size is not None:
744
+ padded_size = (pad_size["height"], pad_size["width"])
745
+ else:
746
+ padded_size = get_max_height_width(images, input_data_format=input_data_format)
747
+
748
+ annotation_list = annotations if annotations is not None else [None] * len(images)
749
+ padded_images = []
750
+ padded_annotations = []
751
+ for image, annotation in zip(images, annotation_list):
752
+ padded_image, padded_annotation = self._pad_image(
753
+ image,
754
+ padded_size,
755
+ annotation,
756
+ constant_values=constant_values,
757
+ data_format=data_format,
758
+ input_data_format=input_data_format,
759
+ update_bboxes=update_bboxes,
760
+ )
761
+ padded_images.append(padded_image)
762
+ padded_annotations.append(padded_annotation)
763
+
764
+ data = {"pixel_values": padded_images}
765
+
766
+ if return_pixel_mask:
767
+ masks = [
768
+ make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
769
+ for image in images
770
+ ]
771
+ data["pixel_mask"] = masks
772
+
773
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
774
+
775
+ if annotations is not None:
776
+ encoded_inputs["labels"] = [
777
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations
778
+ ]
779
+
780
+ return encoded_inputs
781
+
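# Illustrative sketch (assumes RTDetrImageProcessor is importable from transformers): padding two
# differently sized images to a common canvas and reading back the matching pixel masks.
import numpy as np
from transformers import RTDetrImageProcessor

processor = RTDetrImageProcessor()
images = [np.zeros((3, 480, 640), dtype=np.float32), np.zeros((3, 512, 512), dtype=np.float32)]
batch = processor.pad(images, return_tensors="np")
print(batch["pixel_values"].shape)  # (2, 3, 512, 640): both images padded to the batch maximum
print(batch["pixel_mask"].shape)    # (2, 512, 640): 1 marks real pixels, 0 marks padding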
782
+ @filter_out_non_signature_kwargs()
783
+ def preprocess(
784
+ self,
785
+ images: ImageInput,
786
+ annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
787
+ return_segmentation_masks: bool = None,
788
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
789
+ do_resize: Optional[bool] = None,
790
+ size: Optional[Dict[str, int]] = None,
791
+ resample=None, # PILImageResampling
792
+ do_rescale: Optional[bool] = None,
793
+ rescale_factor: Optional[Union[int, float]] = None,
794
+ do_normalize: Optional[bool] = None,
795
+ do_convert_annotations: Optional[bool] = None,
796
+ image_mean: Optional[Union[float, List[float]]] = None,
797
+ image_std: Optional[Union[float, List[float]]] = None,
798
+ do_pad: Optional[bool] = None,
799
+ format: Optional[Union[str, AnnotationFormat]] = None,
800
+ return_tensors: Optional[Union[TensorType, str]] = None,
801
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
802
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
803
+ pad_size: Optional[Dict[str, int]] = None,
804
+ ) -> BatchFeature:
805
+ """
806
+ Preprocess an image or a batch of images so that it can be used by the model.
807
+
808
+ Args:
809
+ images (`ImageInput`):
810
+ Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
811
+ from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
812
+ annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
813
+ List of annotations associated with the image or batch of images. If annotation is for object
814
+ detection, the annotations should be a dictionary with the following keys:
815
+ - "image_id" (`int`): The image id.
816
+ - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
817
+ dictionary. An image can have no annotations, in which case the list should be empty.
818
+ If annotation is for segmentation, the annotations should be a dictionary with the following keys:
819
+ - "image_id" (`int`): The image id.
820
+ - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
821
+ An image can have no segments, in which case the list should be empty.
822
+ - "file_name" (`str`): The file name of the image.
823
+ return_segmentation_masks (`bool`, *optional*, defaults to `False`):
824
+ Whether to return segmentation masks.
825
+ masks_path (`str` or `pathlib.Path`, *optional*):
826
+ Path to the directory containing the segmentation masks.
827
+ do_resize (`bool`, *optional*, defaults to self.do_resize):
828
+ Whether to resize the image.
829
+ size (`Dict[str, int]`, *optional*, defaults to self.size):
830
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
831
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
832
+ Do NOT keep the aspect ratio.
833
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
834
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
835
+ less or equal to `longest_edge`.
836
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
837
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
838
+ `max_width`.
839
+ resample (`PILImageResampling`, *optional*, defaults to self.resample):
840
+ Resampling filter to use when resizing the image.
841
+ do_rescale (`bool`, *optional*, defaults to self.do_rescale):
842
+ Whether to rescale the image.
843
+ rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
844
+ Rescale factor to use when rescaling the image.
845
+ do_normalize (`bool`, *optional*, defaults to self.do_normalize):
846
+ Whether to normalize the image.
847
+ do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
848
+ Whether to convert the annotations to the format expected by the model. Converts the bounding
849
+ boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
850
+ and in relative coordinates.
851
+ image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean):
852
+ Mean to use when normalizing the image.
853
+ image_std (`float` or `List[float]`, *optional*, defaults to self.image_std):
854
+ Standard deviation to use when normalizing the image.
855
+ do_pad (`bool`, *optional*, defaults to self.do_pad):
856
+ Whether to pad the image. If `True`, padding will be applied to the bottom and right of
857
+ the image with zeros. If `pad_size` is provided, the image will be padded to the specified
858
+ dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
859
+ format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
860
+ Format of the annotations.
861
+ return_tensors (`str` or `TensorType`, *optional*):
862
+ Type of tensors to return. If `None`, will return the list of images.
863
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
864
+ The channel dimension format for the output image. Can be one of:
865
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
866
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
867
+ - Unset: Use the channel dimension format of the input image.
868
+ input_data_format (`ChannelDimension` or `str`, *optional*):
869
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
870
+ from the input image. Can be one of:
871
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
872
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
873
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
874
+ pad_size (`Dict[str, int]`, *optional*):
875
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
876
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
877
+ height and width in the batch.
878
+ """
879
+ do_resize = self.do_resize if do_resize is None else do_resize
880
+ size = self.size if size is None else size
881
+ size = get_size_dict(size=size, default_to_square=True)
882
+ resample = self.resample if resample is None else resample
883
+ do_rescale = self.do_rescale if do_rescale is None else do_rescale
884
+ rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
885
+ do_normalize = self.do_normalize if do_normalize is None else do_normalize
886
+ image_mean = self.image_mean if image_mean is None else image_mean
887
+ image_std = self.image_std if image_std is None else image_std
888
+ do_convert_annotations = (
889
+ self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
890
+ )
891
+ do_pad = self.do_pad if do_pad is None else do_pad
892
+ pad_size = self.pad_size if pad_size is None else pad_size
893
+ format = self.format if format is None else format
894
+
895
+ images = make_list_of_images(images)
896
+
897
+ if not valid_images(images):
898
+ raise ValueError(
899
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
900
+ "torch.Tensor, tf.Tensor or jax.ndarray."
901
+ )
902
+
903
+ # No size validation is needed for padding here: pad() pads to the maximum (height, width) in the batch.
904
+
905
+ validate_preprocess_arguments(
906
+ do_rescale=do_rescale,
907
+ rescale_factor=rescale_factor,
908
+ do_normalize=do_normalize,
909
+ image_mean=image_mean,
910
+ image_std=image_std,
911
+ do_resize=do_resize,
912
+ size=size,
913
+ resample=resample,
914
+ )
915
+
916
+ if annotations is not None and isinstance(annotations, dict):
917
+ annotations = [annotations]
918
+
919
+ if annotations is not None and len(images) != len(annotations):
920
+ raise ValueError(
921
+ f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
922
+ )
923
+
924
+ format = AnnotationFormat(format)
925
+ if annotations is not None:
926
+ validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)
927
+
935
+ # All transformations expect numpy arrays
936
+ images = [to_numpy_array(image) for image in images]
937
+
938
+ if do_rescale and is_scaled_image(images[0]):
939
+ logger.warning_once(
940
+ "It looks like you are trying to rescale already rescaled images. If the input"
941
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
942
+ )
943
+
944
+ if input_data_format is None:
945
+ # We assume that all images have the same channel dimension format.
946
+ input_data_format = infer_channel_dimension_format(images[0])
947
+
948
+ # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
949
+ if annotations is not None:
950
+ prepared_images = []
951
+ prepared_annotations = []
952
+ for image, target in zip(images, annotations):
953
+ target = self.prepare_annotation(
954
+ image,
955
+ target,
956
+ format,
957
+ return_segmentation_masks=return_segmentation_masks,
958
+ masks_path=masks_path,
959
+ input_data_format=input_data_format,
960
+ )
961
+ prepared_images.append(image)
962
+ prepared_annotations.append(target)
963
+ images = prepared_images
964
+ annotations = prepared_annotations
965
+ del prepared_images, prepared_annotations
966
+
967
+ # transformations
968
+ if do_resize:
969
+ if annotations is not None:
970
+ resized_images, resized_annotations = [], []
971
+ for image, target in zip(images, annotations):
972
+ orig_size = get_image_size(image, input_data_format)
973
+ resized_image = self.resize(
974
+ image, size=size, resample=resample, input_data_format=input_data_format
975
+ )
976
+ resized_annotation = self.resize_annotation(
977
+ target, orig_size, get_image_size(resized_image, input_data_format)
978
+ )
979
+ resized_images.append(resized_image)
980
+ resized_annotations.append(resized_annotation)
981
+ images = resized_images
982
+ annotations = resized_annotations
983
+ del resized_images, resized_annotations
984
+ else:
985
+ images = [
986
+ self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
987
+ for image in images
988
+ ]
989
+
990
+ if do_rescale:
991
+ images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]
992
+
993
+ if do_normalize:
994
+ images = [
995
+ self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
996
+ ]
997
+
998
+ if do_convert_annotations and annotations is not None:
999
+ annotations = [
1000
+ self.normalize_annotation(annotation, get_image_size(image, input_data_format))
1001
+ for annotation, image in zip(annotations, images)
1002
+ ]
1003
+
1004
+ if do_pad:
1005
+ # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
1006
+ encoded_inputs = self.pad(
1007
+ images,
1008
+ annotations=annotations,
1009
+ return_pixel_mask=True,
1010
+ data_format=data_format,
1011
+ input_data_format=input_data_format,
1012
+ update_bboxes=do_convert_annotations,
1013
+ return_tensors=return_tensors,
1014
+ pad_size=pad_size,
1015
+ )
1016
+ else:
1017
+ images = [
1018
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
1019
+ for image in images
1020
+ ]
1021
+ encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
1022
+ if annotations is not None:
1023
+ encoded_inputs["labels"] = [
1024
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
1025
+ ]
1026
+
1027
+ return encoded_inputs
1028
+
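# Illustrative sketch (assumes RTDetrImageProcessor is importable from transformers and torch is
# installed): a full preprocessing call with a COCO-style target, returning PyTorch tensors.
import numpy as np
from transformers import RTDetrImageProcessor

processor = RTDetrImageProcessor()
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
annotations = {
    "image_id": 0,
    "annotations": [{"bbox": [10, 20, 100, 50], "category_id": 1, "area": 5000, "iscrowd": 0}],
}
inputs = processor(images=image, annotations=annotations, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 640, 640])
print(inputs["labels"][0]["boxes"])  # boxes normalized to (center_x, center_y, width, height) in [0, 1]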
1029
+ def post_process_object_detection(
1030
+ self,
1031
+ outputs,
1032
+ threshold: float = 0.5,
1033
+ target_sizes: Union[TensorType, List[Tuple]] = None,
1034
+ use_focal_loss: bool = True,
1035
+ ):
1036
+ """
1037
+ Converts the raw output of [`RTDetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
1038
+ bottom_right_x, bottom_right_y) format. Only supports PyTorch.
1039
+
1040
+ Args:
1041
+ outputs ([`RTDetrObjectDetectionOutput`]):
1042
+ Raw outputs of the model.
1043
+ threshold (`float`, *optional*, defaults to 0.5):
1044
+ Score threshold to keep object detection predictions.
1045
+ target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
1046
+ Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
1047
+ `(height, width)` of each image in the batch. If unset, predictions will not be resized.
1048
+ use_focal_loss (`bool`, *optional*, defaults to `True`):
1049
+ Variable informing if the focal loss was used to predict the outputs. If `True`, a sigmoid is applied
1050
+ to compute the scores of each detection, otherwise, a softmax function is used.
1051
+
1052
+ Returns:
1053
+ `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
1054
+ in the batch as predicted by the model.
1055
+ """
1056
+ requires_backends(self, ["torch"])
1057
+ out_logits, out_bbox = outputs.logits, outputs.pred_boxes
1058
+ # convert from relative cxcywh to absolute xyxy
1059
+ boxes = center_to_corners_format(out_bbox)
1060
+ if target_sizes is not None:
1061
+ if len(out_logits) != len(target_sizes):
1062
+ raise ValueError(
1063
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
1064
+ )
1065
+ if isinstance(target_sizes, List):
1066
+ img_h, img_w = torch.as_tensor(target_sizes).unbind(1)
1067
+ else:
1068
+ img_h, img_w = target_sizes.unbind(1)
1069
+ scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
1070
+ boxes = boxes * scale_fct[:, None, :]
1071
+
1072
+ num_top_queries = out_logits.shape[1]
1073
+ num_classes = out_logits.shape[2]
1074
+
1075
+ if use_focal_loss:
1076
+ scores = torch.nn.functional.sigmoid(out_logits)
1077
+ scores, index = torch.topk(scores.flatten(1), num_top_queries, dim=-1)
1078
+ labels = index % num_classes
1079
+ index = index // num_classes
1080
+ boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
1081
+ else:
1082
+ scores = torch.nn.functional.softmax(out_logits, dim=-1)[:, :, :-1]
1083
+ scores, labels = scores.max(dim=-1)
1084
+ if scores.shape[1] > num_top_queries:
1085
+ scores, index = torch.topk(scores, num_top_queries, dim=-1)
1086
+ labels = torch.gather(labels, dim=1, index=index)
1087
+ boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
1088
+
1089
+ results = []
1090
+ for score, label, box in zip(scores, labels, boxes):
1091
+ results.append(
1092
+ {
1093
+ "scores": score[score > threshold],
1094
+ "labels": label[score > threshold],
1095
+ "boxes": box[score > threshold],
1096
+ }
1097
+ )
1098
+
1099
+ return results
1100
+
1101
+
1102
+ __all__ = ["RTDetrImageProcessor"]
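# Illustrative end-to-end sketch (the checkpoint name "PekingU/rtdetr_r50vd" and the image path are
# examples, not taken from this file; requires torch and Pillow): run RT-DETR and convert its raw
# outputs with post_process_object_detection.
import torch
from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")

image = Image.open("example.jpg")  # placeholder path
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# target_sizes rescales the relative boxes back to the original (height, width) of each image
results = processor.post_process_object_detection(
    outputs, threshold=0.5, target_sizes=[(image.height, image.width)]
)
for score, label, box in zip(results[0]["scores"], results[0]["labels"], results[0]["boxes"]):
    print(f"{model.config.id2label[label.item()]}: {score:.2f} at {box.tolist()}")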