PeteBleackley commited on
Commit
8ae4e06
·
verified ·
1 Parent(s): 0218456

End of training

Browse files
README.md CHANGED
@@ -10,6 +10,8 @@ metrics:
10
  - precision
11
  - recall
12
  - f1
 
 
13
  model-index:
14
  - name: DisamBertCrossEncoder-base
15
  results: []
@@ -20,13 +22,14 @@ should probably proofread and complete it, then remove this comment. -->
20
 
21
  # DisamBertCrossEncoder-base
22
 
23
- This model is a fine-tuned version of [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) on the semcor dataset.
24
  It achieves the following results on the evaluation set:
25
- - Loss: 13.9274
26
- - Precision: 0.6274
27
- - Recall: 0.6398
28
- - F1: 0.6335
29
- - Matthews: 0.6392
 
30
 
31
  ## Model description
32
 
@@ -45,35 +48,36 @@ More information needed
45
  ### Training hyperparameters
46
 
47
  The following hyperparameters were used during training:
48
- - learning_rate: 0.0001
49
  - train_batch_size: 64
50
  - eval_batch_size: 64
51
  - seed: 42
 
 
52
  - optimizer: Use OptimizerNames.ADAMW_TORCH_FUSED with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
53
- - lr_scheduler_type: inverse_sqrt
54
- - lr_scheduler_warmup_steps: 1000
55
  - num_epochs: 10
56
 
57
  ### Training results
58
 
59
- | Training Loss | Epoch | Step | Validation Loss | Precision | Recall | F1 | Matthews |
60
- |:-------------:|:-----:|:-----:|:---------------:|:---------:|:------:|:------:|:--------:|
61
- | No log | 0 | 0 | 427.8458 | 0.5014 | 0.4790 | 0.4899 | 0.4781 |
62
- | 9.0689 | 1.0 | 3504 | 15.5744 | 0.6010 | 0.6196 | 0.6102 | 0.6190 |
63
- | 8.5420 | 2.0 | 7008 | 15.4129 | 0.6088 | 0.6253 | 0.6170 | 0.6247 |
64
- | 7.8106 | 3.0 | 10512 | 14.3562 | 0.6138 | 0.6328 | 0.6232 | 0.6322 |
65
- | 7.6303 | 4.0 | 14016 | 13.9741 | 0.6157 | 0.6372 | 0.6262 | 0.6366 |
66
- | 7.6930 | 5.0 | 17520 | 13.8324 | 0.6262 | 0.6402 | 0.6331 | 0.6397 |
67
- | 7.4897 | 6.0 | 21024 | 13.9649 | 0.6144 | 0.6323 | 0.6232 | 0.6318 |
68
- | 7.3819 | 7.0 | 24528 | 13.4877 | 0.6273 | 0.6407 | 0.6339 | 0.6401 |
69
- | 7.4083 | 8.0 | 28032 | 13.7249 | 0.6321 | 0.6402 | 0.6362 | 0.6397 |
70
- | 7.0140 | 9.0 | 31536 | 13.5219 | 0.6168 | 0.6389 | 0.6277 | 0.6383 |
71
- | 7.7287 | 10.0 | 35040 | 13.9274 | 0.6274 | 0.6398 | 0.6335 | 0.6392 |
72
 
73
 
74
  ### Framework versions
75
 
76
- - Transformers 5.2.0
77
  - Pytorch 2.10.0+cu128
78
  - Datasets 4.5.0
79
  - Tokenizers 0.22.2
 
10
  - precision
11
  - recall
12
  - f1
13
+ - accuracy
14
+ - matthews_correlation
15
  model-index:
16
  - name: DisamBertCrossEncoder-base
17
  results: []
 
22
 
23
  # DisamBertCrossEncoder-base
24
 
25
+ This model is a fine-tuned version of [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base) on the semcor dataset.
26
  It achieves the following results on the evaluation set:
27
+ - Loss: 0.9841
28
+ - Precision: 0.6896
29
+ - Recall: 0.6396
30
+ - F1: 0.6636
31
+ - Accuracy: 0.9412
32
+ - Matthews Correlation: 0.6320
33
 
34
  ## Model description
35
 
 
48
  ### Training hyperparameters
49
 
50
  The following hyperparameters were used during training:
51
+ - learning_rate: 1e-05
52
  - train_batch_size: 64
53
  - eval_batch_size: 64
54
  - seed: 42
55
+ - gradient_accumulation_steps: 5
56
+ - total_train_batch_size: 320
57
  - optimizer: Use OptimizerNames.ADAMW_TORCH_FUSED with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
58
+ - lr_scheduler_type: cosine
 
59
  - num_epochs: 10
60
 
61
  ### Training results
62
 
63
+ | Training Loss | Epoch | Step | Validation Loss | Precision | Recall | F1 | Accuracy | Matthews Correlation |
64
+ |:-------------:|:-----:|:------:|:---------------:|:---------:|:------:|:------:|:--------:|:--------------------:|
65
+ | No log | 0 | 0 | 430.2531 | 0.0905 | 0.9978 | 0.1660 | 0.0911 | -0.0157 |
66
+ | 0.0672 | 1.0 | 12551 | 0.1555 | 0.6786 | 0.5846 | 0.6281 | 0.9372 | 0.5960 |
67
+ | 0.0550 | 2.0 | 25102 | 0.1447 | 0.7176 | 0.6813 | 0.6990 | 0.9468 | 0.6701 |
68
+ | 0.0427 | 3.0 | 37653 | 0.1498 | 0.7690 | 0.6440 | 0.7010 | 0.9502 | 0.6772 |
69
+ | 0.0309 | 4.0 | 50204 | 0.1779 | 0.6773 | 0.7011 | 0.6890 | 0.9426 | 0.6575 |
70
+ | 0.0179 | 5.0 | 62755 | 0.2554 | 0.7021 | 0.6681 | 0.6847 | 0.9442 | 0.6543 |
71
+ | 0.0092 | 6.0 | 75306 | 0.3257 | 0.6927 | 0.6637 | 0.6779 | 0.9428 | 0.6467 |
72
+ | 0.0047 | 7.0 | 87857 | 0.4757 | 0.6674 | 0.6791 | 0.6732 | 0.9402 | 0.6403 |
73
+ | 0.0022 | 8.0 | 100408 | 0.6664 | 0.6943 | 0.6440 | 0.6682 | 0.9420 | 0.6370 |
74
+ | 0.0011 | 9.0 | 112959 | 0.8230 | 0.6872 | 0.6374 | 0.6613 | 0.9408 | 0.6295 |
75
+ | 0.0009 | 10.0 | 125510 | 0.9841 | 0.6896 | 0.6396 | 0.6636 | 0.9412 | 0.6320 |
76
 
77
 
78
  ### Framework versions
79
 
80
+ - Transformers 5.3.0
81
  - Pytorch 2.10.0+cu128
82
  - Datasets 4.5.0
83
  - Tokenizers 0.22.2
config.json CHANGED
@@ -1,12 +1,9 @@
1
  {
2
  "architectures": [
3
- "DisamBertSingleSense"
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
7
- "auto_map": {
8
- "AutoModel": "DisamBertSingleSense.DisamBertSingleSense"
9
- },
10
  "bos_token_id": null,
11
  "classifier_activation": "gelu",
12
  "classifier_bias": false,
@@ -15,12 +12,10 @@
15
  "cls_token_id": 50281,
16
  "decoder_bias": true,
17
  "deterministic_flash_attn": false,
18
- "dtype": "bfloat16",
19
  "embedding_dropout": 0.0,
20
- "end_token": 50369,
21
  "eos_token_id": null,
22
  "global_attn_every_n_layers": 3,
23
- "gloss_token": 50370,
24
  "gradient_checkpointing": false,
25
  "hidden_activation": "gelu",
26
  "hidden_size": 768,
@@ -77,9 +72,9 @@
77
  "sep_token_id": 50282,
78
  "sparse_pred_ignore_index": -100,
79
  "sparse_prediction": false,
80
- "start_token": 50368,
81
  "tie_word_embeddings": true,
82
- "transformers_version": "5.2.0",
 
83
  "use_cache": false,
84
- "vocab_size": 50371
85
  }
 
1
  {
2
  "architectures": [
3
+ "DisamBertCrossEncoder"
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
 
 
 
7
  "bos_token_id": null,
8
  "classifier_activation": "gelu",
9
  "classifier_bias": false,
 
12
  "cls_token_id": 50281,
13
  "decoder_bias": true,
14
  "deterministic_flash_attn": false,
15
+ "dtype": "float32",
16
  "embedding_dropout": 0.0,
 
17
  "eos_token_id": null,
18
  "global_attn_every_n_layers": 3,
 
19
  "gradient_checkpointing": false,
20
  "hidden_activation": "gelu",
21
  "hidden_size": 768,
 
72
  "sep_token_id": 50282,
73
  "sparse_pred_ignore_index": -100,
74
  "sparse_prediction": false,
 
75
  "tie_word_embeddings": true,
76
+ "tokenizer_path": "answerdotai/ModernBERT-base",
77
+ "transformers_version": "5.3.0",
78
  "use_cache": false,
79
+ "vocab_size": 50368
80
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:16fd64c635d8aa42aea9369b6196dc02bae7771edbb9dc4db1231ab774844f91
3
- size 298047648
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11ae98a5fdc6ef95a0c62701efd72e8656fa76263c9683a3dc6fd26a7b8e0df1
3
+ size 596071480
modeling_deberta_v2.py ADDED
@@ -0,0 +1,1364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Microsoft and the Hugging Face Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """PyTorch DeBERTa-v2 model."""
15
+
16
+ from collections.abc import Sequence
17
+
18
+ import torch
19
+ from torch import nn
20
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
21
+
22
+ from ... import initialization as init
23
+ from ...activations import ACT2FN
24
+ from ...modeling_layers import GradientCheckpointingLayer
25
+ from ...modeling_outputs import (
26
+ BaseModelOutput,
27
+ MaskedLMOutput,
28
+ MultipleChoiceModelOutput,
29
+ QuestionAnsweringModelOutput,
30
+ SequenceClassifierOutput,
31
+ TokenClassifierOutput,
32
+ )
33
+ from ...modeling_utils import PreTrainedModel
34
+ from ...utils import auto_docstring, logging
35
+ from .configuration_deberta_v2 import DebertaV2Config
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+
41
# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
class DebertaV2SelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # Project and regularize the attention output, then fold in the
        # residual connection before normalizing.
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
54
+
55
+
56
@torch.jit.script
def make_log_bucket_position(relative_pos, bucket_size: int, max_position: int):
    """Map raw relative positions into log-spaced buckets.

    Positions within (-bucket_size//2, bucket_size//2) keep their exact value;
    positions outside that window are compressed logarithmically so that the
    bucket index stays within [-mid, mid] while covering up to max_position.
    """
    sign = torch.sign(relative_pos)
    mid = bucket_size // 2
    # Inside the linear window use a placeholder (mid - 1) so the log below is
    # well-defined; those entries are overwritten by the final torch.where.
    abs_pos = torch.where(
        (relative_pos < mid) & (relative_pos > -mid),
        torch.tensor(mid - 1).type_as(relative_pos),
        torch.abs(relative_pos),
    )
    # Logarithmic compression of |pos| from [mid, max_position) onto (mid, 2*mid).
    log_pos = (
        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
    )
    # Keep exact positions in the linear window, signed log-buckets elsewhere.
    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
    return bucket_pos
70
+
71
+
72
def build_relative_position(query_layer, key_layer, bucket_size: int = -1, max_position: int = -1):
    """
    Build relative position according to the query and key

    We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
    \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
    P_k\\)

    Args:
        query_layer (`torch.Tensor`): tensor whose size(-2) is the query length and whose device is used
        key_layer (`torch.Tensor`): tensor whose size(-2) is the key length
        bucket_size (int): the size of position bucket; log-bucketing is applied only when > 0
        max_position (int): the maximum allowed absolute position; log-bucketing is applied only when > 0

    Return:
        `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
    """
    query_size = query_layer.size(-2)
    key_size = key_layer.size(-2)

    q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device)
    k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device)
    # Pairwise difference: rel_pos_ids[i, j] = i - j.
    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
    if bucket_size > 0 and max_position > 0:
        rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
    rel_pos_ids = rel_pos_ids.to(torch.long)
    rel_pos_ids = rel_pos_ids[:query_size, :]
    rel_pos_ids = rel_pos_ids.unsqueeze(0)
    return rel_pos_ids
102
+
103
+
104
@torch.jit.script
# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
    # Broadcast content->position indices to the full gather shape expected by
    # the attention score tensor.
    return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
108
+
109
+
110
@torch.jit.script
# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
    # Broadcast position->content indices over the batch and key-length dims.
    return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
114
+
115
+
116
@torch.jit.script
# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
def pos_dynamic_expand(pos_index, p2c_att, key_layer):
    # Expand position indices to match the leading dims of the p2c attention map.
    return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
120
+
121
+
122
@torch.jit.script
def scaled_size_sqrt(query_layer: torch.Tensor, scale_factor: int):
    # Attention scaling term: sqrt(head_dim * scale_factor), returned as a
    # float tensor so it can be cast to the caller's dtype.
    return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
125
+
126
+
127
@torch.jit.script
def build_rpos(query_layer, key_layer, relative_pos, position_buckets: int, max_relative_positions: int):
    # Reuse the precomputed relative positions when query and key lengths
    # match; otherwise rebuild them key-vs-key. NOTE: the parameter order here
    # is (bucket size, then max position) — callers must pass
    # position_buckets fourth and max_relative_positions fifth.
    if key_layer.size(-2) != query_layer.size(-2):
        return build_relative_position(
            key_layer,
            key_layer,
            bucket_size=position_buckets,
            max_position=max_relative_positions,
        )
    else:
        return relative_pos
138
+
139
+
140
class DisentangledSelfAttention(nn.Module):
    """
    Disentangled self-attention module

    Parameters:
        config (`DebertaV2Config`):
            A model config class instance with the configuration to build a new model. The schema is similar to
            *BertConfig*, for more details, please refer [`DebertaV2Config`]

    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = config.num_attention_heads
        _attention_head_size = config.hidden_size // config.num_attention_heads
        self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
        self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
        self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)

        self.share_att_key = getattr(config, "share_att_key", False)
        self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
        self.relative_attention = getattr(config, "relative_attention", False)

        if self.relative_attention:
            self.position_buckets = getattr(config, "position_buckets", -1)
            self.max_relative_positions = getattr(config, "max_relative_positions", -1)
            if self.max_relative_positions < 1:
                self.max_relative_positions = config.max_position_embeddings
            self.pos_ebd_size = self.max_relative_positions
            if self.position_buckets > 0:
                self.pos_ebd_size = self.position_buckets

            self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)

            # Separate positional projections are only needed when the content
            # projections are not shared with the positional embeddings.
            if not self.share_att_key:
                if "c2p" in self.pos_att_type:
                    self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
                if "p2c" in self.pos_att_type:
                    self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x, attention_heads) -> torch.Tensor:
        # [B, T, H*D] -> [B*H, T, D] so per-head attention can use torch.bmm.
        new_x_shape = x.size()[:-1] + (attention_heads, -1)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))

    def forward(
        self,
        hidden_states,
        attention_mask,
        output_attentions=False,
        query_states=None,
        relative_pos=None,
        rel_embeddings=None,
    ):
        """
        Call the module

        Args:
            hidden_states (`torch.FloatTensor`):
                Input states to the module usually the output from previous layer, it will be the Q,K and V in
                *Attention(Q,K,V)*

            attention_mask (`torch.BoolTensor`):
                An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                th token.

            output_attentions (`bool`, *optional*):
                Whether return the attention matrix.

            query_states (`torch.FloatTensor`, *optional*):
                The *Q* state in *Attention(Q,K,V)*.

            relative_pos (`torch.LongTensor`):
                The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
                values ranging in [*-max_relative_positions*, *max_relative_positions*].

            rel_embeddings (`torch.FloatTensor`):
                The embedding of relative distances. It's a tensor of shape [\\(2 \\times
                \\text{max_relative_positions}\\), *hidden_size*].

        Returns:
            tuple of (context_layer, attention_probs or None).
        """
        if query_states is None:
            query_states = hidden_states
        query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
        key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
        value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)

        rel_att = None
        # Take the dot product between "query" and "key" to get the raw attention scores.
        # Each enabled disentangled term (c2p / p2c) contributes to the scale.
        scale_factor = 1
        if "c2p" in self.pos_att_type:
            scale_factor += 1
        if "p2c" in self.pos_att_type:
            scale_factor += 1
        scale = scaled_size_sqrt(query_layer, scale_factor)
        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype))
        if self.relative_attention:
            rel_embeddings = self.pos_dropout(rel_embeddings)
            rel_att = self.disentangled_attention_bias(
                query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
            )

        if rel_att is not None:
            attention_scores = attention_scores + rel_att
        attention_scores = attention_scores.view(
            -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
        )

        # Masked positions get dtype-min so softmax assigns them ~0 probability.
        attention_mask = attention_mask.bool()
        attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
        # bsz x height x length x dimension
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        attention_probs = self.dropout(attention_probs)
        context_layer = torch.bmm(
            attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
        )
        # [B*H, T, D] -> [B, T, H*D]: merge heads back into the hidden dim.
        context_layer = (
            context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
            .permute(0, 2, 1, 3)
            .contiguous()
        )
        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
        context_layer = context_layer.view(new_context_layer_shape)
        if not output_attentions:
            return (context_layer, None)
        return (context_layer, attention_probs)

    def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
        """Compute the content->position and position->content score terms."""
        if relative_pos is None:
            relative_pos = build_relative_position(
                query_layer,
                key_layer,
                bucket_size=self.position_buckets,
                max_position=self.max_relative_positions,
            )
        if relative_pos.dim() == 2:
            relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
        elif relative_pos.dim() == 3:
            relative_pos = relative_pos.unsqueeze(1)
        # bsz x height x query x key
        elif relative_pos.dim() != 4:
            raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")

        att_span = self.pos_ebd_size
        relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long)

        rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
        if self.share_att_key:
            # Reuse the content projections for the positional embeddings.
            pos_query_layer = self.transpose_for_scores(
                self.query_proj(rel_embeddings), self.num_attention_heads
            ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
            pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
                query_layer.size(0) // self.num_attention_heads, 1, 1
            )
        else:
            if "c2p" in self.pos_att_type:
                pos_key_layer = self.transpose_for_scores(
                    self.pos_key_proj(rel_embeddings), self.num_attention_heads
                ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)  # .split(self.all_head_size, dim=-1)
            if "p2c" in self.pos_att_type:
                pos_query_layer = self.transpose_for_scores(
                    self.pos_query_proj(rel_embeddings), self.num_attention_heads
                ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)  # .split(self.all_head_size, dim=-1)

        score = 0
        # content->position
        if "c2p" in self.pos_att_type:
            scale = scaled_size_sqrt(pos_key_layer, scale_factor)
            c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
            c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
            c2p_att = torch.gather(
                c2p_att,
                dim=-1,
                index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
            )
            score += c2p_att / scale.to(dtype=c2p_att.dtype)

        # position->content
        if "p2c" in self.pos_att_type:
            scale = scaled_size_sqrt(pos_query_layer, scale_factor)
            # FIX: pass arguments in the order build_rpos declares them —
            # (bucket size, then max position). The previous call site had
            # them swapped, which only mattered when query and key lengths
            # differ (the branch that rebuilds relative positions).
            r_pos = build_rpos(
                query_layer,
                key_layer,
                relative_pos,
                self.position_buckets,
                self.max_relative_positions,
            )
            p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
            p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
            p2c_att = torch.gather(
                p2c_att,
                dim=-1,
                index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
            ).transpose(-1, -2)
            score += p2c_att / scale.to(dtype=p2c_att.dtype)

        return score
349
+
350
+
351
# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
class DebertaV2Attention(nn.Module):
    """Disentangled self-attention followed by the self-output projection."""

    def __init__(self, config):
        super().__init__()
        self.self = DisentangledSelfAttention(config)
        self.output = DebertaV2SelfOutput(config)
        self.config = config

    def forward(
        self,
        hidden_states,
        attention_mask,
        output_attentions: bool = False,
        query_states=None,
        relative_pos=None,
        rel_embeddings=None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Run disentangled attention, then combine with the residual branch.
        self_output, attn_weights = self.self(
            hidden_states,
            attention_mask,
            output_attentions,
            query_states=query_states,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
        )
        # When no separate query stream is supplied, the residual is the input.
        residual = hidden_states if query_states is None else query_states
        attention_output = self.output(self_output, residual)

        return (attention_output, attn_weights if output_attentions else None)
384
+
385
+
386
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
387
+ class DebertaV2Intermediate(nn.Module):
388
+ def __init__(self, config):
389
+ super().__init__()
390
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
391
+ if isinstance(config.hidden_act, str):
392
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
393
+ else:
394
+ self.intermediate_act_fn = config.hidden_act
395
+
396
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
397
+ hidden_states = self.dense(hidden_states)
398
+ hidden_states = self.intermediate_act_fn(hidden_states)
399
+ return hidden_states
400
+
401
+
402
# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm
class DebertaV2Output(nn.Module):
    """Feed-forward down-projection with dropout, residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, input_tensor):
        # Project back to hidden size, regularize, then normalize around the
        # residual branch.
        down = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(down + input_tensor)
416
+
417
+
418
# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
class DebertaV2Layer(GradientCheckpointingLayer):
    """One transformer block: attention, intermediate FFN, and output."""

    def __init__(self, config):
        super().__init__()
        self.attention = DebertaV2Attention(config)
        self.intermediate = DebertaV2Intermediate(config)
        self.output = DebertaV2Output(config)

    def forward(
        self,
        hidden_states,
        attention_mask,
        query_states=None,
        relative_pos=None,
        rel_embeddings=None,
        output_attentions: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Attention sub-block (returns weights only when requested).
        attn_out, attn_weights = self.attention(
            hidden_states,
            attention_mask,
            output_attentions=output_attentions,
            query_states=query_states,
            relative_pos=relative_pos,
            rel_embeddings=rel_embeddings,
        )
        # Feed-forward sub-block with residual back to the attention output.
        ffn_hidden = self.intermediate(attn_out)
        layer_output = self.output(ffn_hidden, attn_out)

        return (layer_output, attn_weights if output_attentions else None)
450
+
451
+
452
class ConvLayer(nn.Module):
    """Depth-preserving 1-D convolution branch applied alongside the embeddings,
    with masking, activation, residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        kernel_size = getattr(config, "conv_kernel_size", 3)
        groups = getattr(config, "conv_groups", 1)
        self.conv_act = getattr(config, "conv_act", "tanh")
        # "same" padding for odd kernel sizes: output length == input length.
        self.conv = nn.Conv1d(
            config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
        )
        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, hidden_states, residual_states, input_mask):
        # Conv1d expects channels-first: [B, T, H] -> [B, H, T] -> conv -> [B, T, H].
        out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
        # NOTE(review): input_mask is dereferenced here, but the None check
        # only happens further down — a None mask would raise before reaching
        # it. Confirm callers always pass a mask.
        rmask = (1 - input_mask).bool()
        # Zero out padded positions before activation.
        out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
        out = ACT2FN[self.conv_act](self.dropout(out))

        layer_norm_input = residual_states + out
        output = self.LayerNorm(layer_norm_input).to(layer_norm_input)

        if input_mask is None:
            output_states = output
        else:
            # Normalize mask rank to match [B, T, H] broadcasting:
            # a 4-D attention-style mask is squeezed to [B, T] first.
            if input_mask.dim() != layer_norm_input.dim():
                if input_mask.dim() == 4:
                    input_mask = input_mask.squeeze(1).squeeze(1)
                input_mask = input_mask.unsqueeze(2)

            input_mask = input_mask.to(output.dtype)
            output_states = output * input_mask

        return output_states
486
+
487
+
488
# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm,Deberta->DebertaV2
class DebertaV2Embeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        pad_token_id = getattr(config, "pad_token_id", 0)
        # embedding_size may differ from hidden_size; a projection bridges them.
        self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)

        self.position_biased_input = getattr(config, "position_biased_input", True)
        if not self.position_biased_input:
            self.position_embeddings = None
        else:
            self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)

        # Token-type embeddings are optional (type_vocab_size == 0 disables them).
        if config.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
        else:
            self.token_type_embeddings = None

        if self.embedding_size != config.hidden_size:
            self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
        else:
            self.embed_proj = None

        self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
        # Derive the sequence shape from whichever of input_ids / inputs_embeds
        # was provided.
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if self.position_embeddings is not None:
            position_embeddings = self.position_embeddings(position_ids.long())
        else:
            position_embeddings = torch.zeros_like(inputs_embeds)

        # Sum the enabled embedding components.
        embeddings = inputs_embeds
        if self.position_biased_input:
            embeddings = embeddings + position_embeddings
        if self.token_type_embeddings is not None:
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings

        # Bridge embedding_size -> hidden_size when they differ.
        if self.embed_proj is not None:
            embeddings = self.embed_proj(embeddings)

        embeddings = self.LayerNorm(embeddings)

        # Zero out padded positions; a 4-D attention-style mask is first
        # squeezed to [B, T], then unsqueezed to broadcast over hidden dims.
        if mask is not None:
            if mask.dim() != embeddings.dim():
                if mask.dim() == 4:
                    mask = mask.squeeze(1).squeeze(1)
                mask = mask.unsqueeze(2)
            mask = mask.to(embeddings.dtype)

            embeddings = embeddings * mask

        embeddings = self.dropout(embeddings)
        return embeddings
568
+
569
+
570
+ class DebertaV2Encoder(nn.Module):
571
+ """Modified BertEncoder with relative position bias support"""
572
+
573
+ def __init__(self, config):
574
+ super().__init__()
575
+
576
+ self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)])
577
+ self.relative_attention = getattr(config, "relative_attention", False)
578
+
579
+ if self.relative_attention:
580
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
581
+ if self.max_relative_positions < 1:
582
+ self.max_relative_positions = config.max_position_embeddings
583
+
584
+ self.position_buckets = getattr(config, "position_buckets", -1)
585
+ pos_ebd_size = self.max_relative_positions * 2
586
+
587
+ if self.position_buckets > 0:
588
+ pos_ebd_size = self.position_buckets * 2
589
+
590
+ self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)
591
+
592
+ self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]
593
+
594
+ if "layer_norm" in self.norm_rel_ebd:
595
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
596
+
597
+ self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
598
+ self.gradient_checkpointing = False
599
+
600
+ def get_rel_embedding(self):
601
+ rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
602
+ if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
603
+ rel_embeddings = self.LayerNorm(rel_embeddings)
604
+ return rel_embeddings
605
+
606
+ def get_attention_mask(self, attention_mask):
607
+ if attention_mask.dim() <= 2:
608
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
609
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
610
+ elif attention_mask.dim() == 3:
611
+ attention_mask = attention_mask.unsqueeze(1)
612
+
613
+ return attention_mask
614
+
615
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
616
+ if self.relative_attention and relative_pos is None:
617
+ if query_states is not None:
618
+ relative_pos = build_relative_position(
619
+ query_states,
620
+ hidden_states,
621
+ bucket_size=self.position_buckets,
622
+ max_position=self.max_relative_positions,
623
+ )
624
+ else:
625
+ relative_pos = build_relative_position(
626
+ hidden_states,
627
+ hidden_states,
628
+ bucket_size=self.position_buckets,
629
+ max_position=self.max_relative_positions,
630
+ )
631
+ return relative_pos
632
+
633
+ def forward(
634
+ self,
635
+ hidden_states,
636
+ attention_mask,
637
+ output_hidden_states=True,
638
+ output_attentions=False,
639
+ query_states=None,
640
+ relative_pos=None,
641
+ return_dict=True,
642
+ ):
643
+ if attention_mask.dim() <= 2:
644
+ input_mask = attention_mask
645
+ else:
646
+ input_mask = attention_mask.sum(-2) > 0
647
+ attention_mask = self.get_attention_mask(attention_mask)
648
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
649
+
650
+ all_hidden_states: tuple[torch.Tensor] | None = (hidden_states,) if output_hidden_states else None
651
+ all_attentions = () if output_attentions else None
652
+
653
+ next_kv = hidden_states
654
+ rel_embeddings = self.get_rel_embedding()
655
+ for i, layer_module in enumerate(self.layer):
656
+ output_states, attn_weights = layer_module(
657
+ next_kv,
658
+ attention_mask,
659
+ query_states=query_states,
660
+ relative_pos=relative_pos,
661
+ rel_embeddings=rel_embeddings,
662
+ output_attentions=output_attentions,
663
+ )
664
+
665
+ if output_attentions:
666
+ all_attentions = all_attentions + (attn_weights,)
667
+
668
+ if i == 0 and self.conv is not None:
669
+ output_states = self.conv(hidden_states, output_states, input_mask)
670
+
671
+ if output_hidden_states:
672
+ all_hidden_states = all_hidden_states + (output_states,)
673
+
674
+ if query_states is not None:
675
+ query_states = output_states
676
+ if isinstance(hidden_states, Sequence):
677
+ next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
678
+ else:
679
+ next_kv = output_states
680
+
681
+ if not return_dict:
682
+ return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
683
+ return BaseModelOutput(
684
+ last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
685
+ )
686
+
687
+
688
+ @auto_docstring
689
+ class DebertaV2PreTrainedModel(PreTrainedModel):
690
+ config: DebertaV2Config
691
+ base_model_prefix = "deberta"
692
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
693
+ supports_gradient_checkpointing = True
694
+
695
+ @torch.no_grad()
696
+ def _init_weights(self, module):
697
+ """Initialize the weights."""
698
+ super()._init_weights(module)
699
+ if isinstance(module, (LegacyDebertaV2LMPredictionHead, DebertaV2LMPredictionHead)):
700
+ init.zeros_(module.bias)
701
+ elif isinstance(module, DebertaV2Embeddings):
702
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
703
+
704
+
705
+ @auto_docstring
706
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2
707
+ class DebertaV2Model(DebertaV2PreTrainedModel):
708
+ def __init__(self, config):
709
+ super().__init__(config)
710
+
711
+ self.embeddings = DebertaV2Embeddings(config)
712
+ self.encoder = DebertaV2Encoder(config)
713
+ self.z_steps = 0
714
+ self.config = config
715
+ # Initialize weights and apply final processing
716
+ self.post_init()
717
+
718
+ def get_input_embeddings(self):
719
+ return self.embeddings.word_embeddings
720
+
721
+ def set_input_embeddings(self, new_embeddings):
722
+ self.embeddings.word_embeddings = new_embeddings
723
+
724
+ @auto_docstring
725
+ def forward(
726
+ self,
727
+ input_ids: torch.Tensor | None = None,
728
+ attention_mask: torch.Tensor | None = None,
729
+ token_type_ids: torch.Tensor | None = None,
730
+ position_ids: torch.Tensor | None = None,
731
+ inputs_embeds: torch.Tensor | None = None,
732
+ output_attentions: bool | None = None,
733
+ output_hidden_states: bool | None = None,
734
+ return_dict: bool | None = None,
735
+ **kwargs,
736
+ ) -> tuple | BaseModelOutput:
737
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
738
+ output_hidden_states = (
739
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
740
+ )
741
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
742
+
743
+ if input_ids is not None and inputs_embeds is not None:
744
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
745
+ elif input_ids is not None:
746
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
747
+ input_shape = input_ids.size()
748
+ elif inputs_embeds is not None:
749
+ input_shape = inputs_embeds.size()[:-1]
750
+ else:
751
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
752
+
753
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
754
+
755
+ if attention_mask is None:
756
+ attention_mask = torch.ones(input_shape, device=device)
757
+ if token_type_ids is None:
758
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
759
+
760
+ embedding_output = self.embeddings(
761
+ input_ids=input_ids,
762
+ token_type_ids=token_type_ids,
763
+ position_ids=position_ids,
764
+ mask=attention_mask,
765
+ inputs_embeds=inputs_embeds,
766
+ )
767
+
768
+ encoder_outputs = self.encoder(
769
+ embedding_output,
770
+ attention_mask,
771
+ output_hidden_states=True,
772
+ output_attentions=output_attentions,
773
+ return_dict=return_dict,
774
+ )
775
+ encoded_layers = encoder_outputs[1]
776
+
777
+ if self.z_steps > 1:
778
+ hidden_states = encoded_layers[-2]
779
+ layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
780
+ query_states = encoded_layers[-1]
781
+ rel_embeddings = self.encoder.get_rel_embedding()
782
+ attention_mask = self.encoder.get_attention_mask(attention_mask)
783
+ rel_pos = self.encoder.get_rel_pos(embedding_output)
784
+ for layer in layers[1:]:
785
+ query_states = layer(
786
+ hidden_states,
787
+ attention_mask,
788
+ output_attentions=False,
789
+ query_states=query_states,
790
+ relative_pos=rel_pos,
791
+ rel_embeddings=rel_embeddings,
792
+ )
793
+ encoded_layers.append(query_states)
794
+
795
+ sequence_output = encoded_layers[-1]
796
+
797
+ if not return_dict:
798
+ return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
799
+
800
+ return BaseModelOutput(
801
+ last_hidden_state=sequence_output,
802
+ hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
803
+ attentions=encoder_outputs.attentions,
804
+ )
805
+
806
+
807
+ # Copied from transformers.models.deberta.modeling_deberta.LegacyDebertaPredictionHeadTransform with Deberta->DebertaV2
808
+ class LegacyDebertaV2PredictionHeadTransform(nn.Module):
809
+ def __init__(self, config):
810
+ super().__init__()
811
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
812
+
813
+ self.dense = nn.Linear(config.hidden_size, self.embedding_size)
814
+ if isinstance(config.hidden_act, str):
815
+ self.transform_act_fn = ACT2FN[config.hidden_act]
816
+ else:
817
+ self.transform_act_fn = config.hidden_act
818
+ self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps)
819
+
820
+ def forward(self, hidden_states):
821
+ hidden_states = self.dense(hidden_states)
822
+ hidden_states = self.transform_act_fn(hidden_states)
823
+ hidden_states = self.LayerNorm(hidden_states)
824
+ return hidden_states
825
+
826
+
827
+ class LegacyDebertaV2LMPredictionHead(nn.Module):
828
+ def __init__(self, config):
829
+ super().__init__()
830
+ self.transform = LegacyDebertaV2PredictionHeadTransform(config)
831
+
832
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
833
+ # The output weights are the same as the input embeddings, but there is
834
+ # an output-only bias for each token.
835
+ self.decoder = nn.Linear(self.embedding_size, config.vocab_size)
836
+
837
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
838
+
839
+ def forward(self, hidden_states):
840
+ hidden_states = self.transform(hidden_states)
841
+ hidden_states = self.decoder(hidden_states)
842
+ return hidden_states
843
+
844
+
845
+ class LegacyDebertaV2OnlyMLMHead(nn.Module):
846
+ def __init__(self, config):
847
+ super().__init__()
848
+ self.predictions = LegacyDebertaV2LMPredictionHead(config)
849
+
850
+ def forward(self, sequence_output):
851
+ prediction_scores = self.predictions(sequence_output)
852
+ return prediction_scores
853
+
854
+
855
+ class DebertaV2LMPredictionHead(nn.Module):
856
+ """https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py#L270"""
857
+
858
+ def __init__(self, config):
859
+ super().__init__()
860
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
861
+
862
+ if isinstance(config.hidden_act, str):
863
+ self.transform_act_fn = ACT2FN[config.hidden_act]
864
+ else:
865
+ self.transform_act_fn = config.hidden_act
866
+
867
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=True)
868
+
869
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
870
+
871
+ # note that the input embeddings must be passed as an argument
872
+ def forward(self, hidden_states, word_embeddings):
873
+ hidden_states = self.dense(hidden_states)
874
+ hidden_states = self.transform_act_fn(hidden_states)
875
+ hidden_states = self.LayerNorm(hidden_states)
876
+ hidden_states = torch.matmul(hidden_states, word_embeddings.weight.t()) + self.bias
877
+ return hidden_states
878
+
879
+
880
+ class DebertaV2OnlyMLMHead(nn.Module):
881
+ def __init__(self, config):
882
+ super().__init__()
883
+ self.lm_head = DebertaV2LMPredictionHead(config)
884
+
885
+ # note that the input embeddings must be passed as an argument
886
+ def forward(self, sequence_output, word_embeddings):
887
+ prediction_scores = self.lm_head(sequence_output, word_embeddings)
888
+ return prediction_scores
889
+
890
+
891
+ @auto_docstring
892
+ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
893
+ _tied_weights_keys = {
894
+ "cls.predictions.decoder.bias": "cls.predictions.bias",
895
+ "cls.predictions.decoder.weight": "deberta.embeddings.word_embeddings.weight",
896
+ }
897
+ _keys_to_ignore_on_load_unexpected = [r"mask_predictions.*"]
898
+
899
+ def __init__(self, config):
900
+ super().__init__(config)
901
+ self.legacy = config.legacy
902
+ self.deberta = DebertaV2Model(config)
903
+ if self.legacy:
904
+ self.cls = LegacyDebertaV2OnlyMLMHead(config)
905
+ else:
906
+ self._tied_weights_keys = {
907
+ "lm_predictions.lm_head.weight": "deberta.embeddings.word_embeddings.weight",
908
+ }
909
+ self.lm_predictions = DebertaV2OnlyMLMHead(config)
910
+ # Initialize weights and apply final processing
911
+ self.post_init()
912
+
913
+ def get_output_embeddings(self):
914
+ if self.legacy:
915
+ return self.cls.predictions.decoder
916
+ else:
917
+ return self.lm_predictions.lm_head.dense
918
+
919
+ def set_output_embeddings(self, new_embeddings):
920
+ if self.legacy:
921
+ self.cls.predictions.decoder = new_embeddings
922
+ self.cls.predictions.bias = new_embeddings.bias
923
+ else:
924
+ self.lm_predictions.lm_head.dense = new_embeddings
925
+ self.lm_predictions.lm_head.bias = new_embeddings.bias
926
+
927
+ @auto_docstring
928
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM.forward with Deberta->DebertaV2
929
+ def forward(
930
+ self,
931
+ input_ids: torch.Tensor | None = None,
932
+ attention_mask: torch.Tensor | None = None,
933
+ token_type_ids: torch.Tensor | None = None,
934
+ position_ids: torch.Tensor | None = None,
935
+ inputs_embeds: torch.Tensor | None = None,
936
+ labels: torch.Tensor | None = None,
937
+ output_attentions: bool | None = None,
938
+ output_hidden_states: bool | None = None,
939
+ return_dict: bool | None = None,
940
+ **kwargs,
941
+ ) -> tuple | MaskedLMOutput:
942
+ r"""
943
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
944
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
945
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
946
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
947
+ """
948
+
949
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
950
+
951
+ outputs = self.deberta(
952
+ input_ids,
953
+ attention_mask=attention_mask,
954
+ token_type_ids=token_type_ids,
955
+ position_ids=position_ids,
956
+ inputs_embeds=inputs_embeds,
957
+ output_attentions=output_attentions,
958
+ output_hidden_states=output_hidden_states,
959
+ return_dict=return_dict,
960
+ )
961
+
962
+ sequence_output = outputs[0]
963
+ if self.legacy:
964
+ prediction_scores = self.cls(sequence_output)
965
+ else:
966
+ prediction_scores = self.lm_predictions(sequence_output, self.deberta.embeddings.word_embeddings)
967
+
968
+ masked_lm_loss = None
969
+ if labels is not None:
970
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
971
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
972
+
973
+ if not return_dict:
974
+ output = (prediction_scores,) + outputs[1:]
975
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
976
+
977
+ return MaskedLMOutput(
978
+ loss=masked_lm_loss,
979
+ logits=prediction_scores,
980
+ hidden_states=outputs.hidden_states,
981
+ attentions=outputs.attentions,
982
+ )
983
+
984
+
985
+ # Copied from transformers.models.deberta.modeling_deberta.ContextPooler
986
+ class ContextPooler(nn.Module):
987
+ def __init__(self, config):
988
+ super().__init__()
989
+ self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
990
+ self.dropout = nn.Dropout(config.pooler_dropout)
991
+ self.config = config
992
+
993
+ def forward(self, hidden_states):
994
+ # We "pool" the model by simply taking the hidden state corresponding
995
+ # to the first token.
996
+
997
+ context_token = hidden_states[:, 0]
998
+ context_token = self.dropout(context_token)
999
+ pooled_output = self.dense(context_token)
1000
+ pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
1001
+ return pooled_output
1002
+
1003
+ @property
1004
+ def output_dim(self):
1005
+ return self.config.hidden_size
1006
+
1007
+
1008
+ @auto_docstring(
1009
+ custom_intro="""
1010
+ DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1011
+ pooled output) e.g. for GLUE tasks.
1012
+ """
1013
+ )
1014
+ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
1015
+ def __init__(self, config):
1016
+ super().__init__(config)
1017
+
1018
+ num_labels = getattr(config, "num_labels", 2)
1019
+ self.num_labels = num_labels
1020
+
1021
+ self.deberta = DebertaV2Model(config)
1022
+ self.pooler = ContextPooler(config)
1023
+ output_dim = self.pooler.output_dim
1024
+
1025
+ self.classifier = nn.Linear(output_dim, num_labels)
1026
+ drop_out = getattr(config, "cls_dropout", None)
1027
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1028
+ self.dropout = nn.Dropout(drop_out)
1029
+
1030
+ # Initialize weights and apply final processing
1031
+ self.post_init()
1032
+
1033
+ def get_input_embeddings(self):
1034
+ return self.deberta.get_input_embeddings()
1035
+
1036
+ def set_input_embeddings(self, new_embeddings):
1037
+ self.deberta.set_input_embeddings(new_embeddings)
1038
+
1039
+ @auto_docstring
1040
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification.forward with Deberta->DebertaV2
1041
+ def forward(
1042
+ self,
1043
+ input_ids: torch.Tensor | None = None,
1044
+ attention_mask: torch.Tensor | None = None,
1045
+ token_type_ids: torch.Tensor | None = None,
1046
+ position_ids: torch.Tensor | None = None,
1047
+ inputs_embeds: torch.Tensor | None = None,
1048
+ labels: torch.Tensor | None = None,
1049
+ output_attentions: bool | None = None,
1050
+ output_hidden_states: bool | None = None,
1051
+ return_dict: bool | None = None,
1052
+ **kwargs,
1053
+ ) -> tuple | SequenceClassifierOutput:
1054
+ r"""
1055
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1056
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1057
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1058
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1059
+ """
1060
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1061
+
1062
+ outputs = self.deberta(
1063
+ input_ids,
1064
+ token_type_ids=token_type_ids,
1065
+ attention_mask=attention_mask,
1066
+ position_ids=position_ids,
1067
+ inputs_embeds=inputs_embeds,
1068
+ output_attentions=output_attentions,
1069
+ output_hidden_states=output_hidden_states,
1070
+ return_dict=return_dict,
1071
+ )
1072
+
1073
+ encoder_layer = outputs[0]
1074
+ pooled_output = self.pooler(encoder_layer)
1075
+ pooled_output = self.dropout(pooled_output)
1076
+ logits = self.classifier(pooled_output)
1077
+
1078
+ loss = None
1079
+ if labels is not None:
1080
+ if self.config.problem_type is None:
1081
+ if self.num_labels == 1:
1082
+ # regression task
1083
+ loss_fn = nn.MSELoss()
1084
+ logits = logits.view(-1).to(labels.dtype)
1085
+ loss = loss_fn(logits, labels.view(-1))
1086
+ elif labels.dim() == 1 or labels.size(-1) == 1:
1087
+ label_index = (labels >= 0).nonzero()
1088
+ labels = labels.long()
1089
+ if label_index.size(0) > 0:
1090
+ labeled_logits = torch.gather(
1091
+ logits, 0, label_index.expand(label_index.size(0), logits.size(1))
1092
+ )
1093
+ labels = torch.gather(labels, 0, label_index.view(-1))
1094
+ loss_fct = CrossEntropyLoss()
1095
+ loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
1096
+ else:
1097
+ loss = torch.tensor(0).to(logits)
1098
+ else:
1099
+ log_softmax = nn.LogSoftmax(-1)
1100
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
1101
+ elif self.config.problem_type == "regression":
1102
+ loss_fct = MSELoss()
1103
+ if self.num_labels == 1:
1104
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1105
+ else:
1106
+ loss = loss_fct(logits, labels)
1107
+ elif self.config.problem_type == "single_label_classification":
1108
+ loss_fct = CrossEntropyLoss()
1109
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1110
+ elif self.config.problem_type == "multi_label_classification":
1111
+ loss_fct = BCEWithLogitsLoss()
1112
+ loss = loss_fct(logits, labels)
1113
+ if not return_dict:
1114
+ output = (logits,) + outputs[1:]
1115
+ return ((loss,) + output) if loss is not None else output
1116
+
1117
+ return SequenceClassifierOutput(
1118
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
1119
+ )
1120
+
1121
+
1122
+ @auto_docstring
1123
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2
1124
+ class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
1125
+ def __init__(self, config):
1126
+ super().__init__(config)
1127
+ self.num_labels = config.num_labels
1128
+
1129
+ self.deberta = DebertaV2Model(config)
1130
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1131
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1132
+
1133
+ # Initialize weights and apply final processing
1134
+ self.post_init()
1135
+
1136
+ @auto_docstring
1137
+ def forward(
1138
+ self,
1139
+ input_ids: torch.Tensor | None = None,
1140
+ attention_mask: torch.Tensor | None = None,
1141
+ token_type_ids: torch.Tensor | None = None,
1142
+ position_ids: torch.Tensor | None = None,
1143
+ inputs_embeds: torch.Tensor | None = None,
1144
+ labels: torch.Tensor | None = None,
1145
+ output_attentions: bool | None = None,
1146
+ output_hidden_states: bool | None = None,
1147
+ return_dict: bool | None = None,
1148
+ **kwargs,
1149
+ ) -> tuple | TokenClassifierOutput:
1150
+ r"""
1151
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1152
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1153
+ """
1154
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1155
+
1156
+ outputs = self.deberta(
1157
+ input_ids,
1158
+ attention_mask=attention_mask,
1159
+ token_type_ids=token_type_ids,
1160
+ position_ids=position_ids,
1161
+ inputs_embeds=inputs_embeds,
1162
+ output_attentions=output_attentions,
1163
+ output_hidden_states=output_hidden_states,
1164
+ return_dict=return_dict,
1165
+ )
1166
+
1167
+ sequence_output = outputs[0]
1168
+
1169
+ sequence_output = self.dropout(sequence_output)
1170
+ logits = self.classifier(sequence_output)
1171
+
1172
+ loss = None
1173
+ if labels is not None:
1174
+ loss_fct = CrossEntropyLoss()
1175
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1176
+
1177
+ if not return_dict:
1178
+ output = (logits,) + outputs[1:]
1179
+ return ((loss,) + output) if loss is not None else output
1180
+
1181
+ return TokenClassifierOutput(
1182
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
1183
+ )
1184
+
1185
+
1186
+ @auto_docstring
1187
+ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
1188
+ def __init__(self, config):
1189
+ super().__init__(config)
1190
+ self.num_labels = config.num_labels
1191
+
1192
+ self.deberta = DebertaV2Model(config)
1193
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1194
+
1195
+ # Initialize weights and apply final processing
1196
+ self.post_init()
1197
+
1198
+ @auto_docstring
1199
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering.forward with Deberta->DebertaV2
1200
+ def forward(
1201
+ self,
1202
+ input_ids: torch.Tensor | None = None,
1203
+ attention_mask: torch.Tensor | None = None,
1204
+ token_type_ids: torch.Tensor | None = None,
1205
+ position_ids: torch.Tensor | None = None,
1206
+ inputs_embeds: torch.Tensor | None = None,
1207
+ start_positions: torch.Tensor | None = None,
1208
+ end_positions: torch.Tensor | None = None,
1209
+ output_attentions: bool | None = None,
1210
+ output_hidden_states: bool | None = None,
1211
+ return_dict: bool | None = None,
1212
+ **kwargs,
1213
+ ) -> tuple | QuestionAnsweringModelOutput:
1214
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1215
+
1216
+ outputs = self.deberta(
1217
+ input_ids,
1218
+ attention_mask=attention_mask,
1219
+ token_type_ids=token_type_ids,
1220
+ position_ids=position_ids,
1221
+ inputs_embeds=inputs_embeds,
1222
+ output_attentions=output_attentions,
1223
+ output_hidden_states=output_hidden_states,
1224
+ return_dict=return_dict,
1225
+ )
1226
+
1227
+ sequence_output = outputs[0]
1228
+
1229
+ logits = self.qa_outputs(sequence_output)
1230
+ start_logits, end_logits = logits.split(1, dim=-1)
1231
+ start_logits = start_logits.squeeze(-1).contiguous()
1232
+ end_logits = end_logits.squeeze(-1).contiguous()
1233
+
1234
+ total_loss = None
1235
+ if start_positions is not None and end_positions is not None:
1236
+ # If we are on multi-GPU, split add a dimension
1237
+ if len(start_positions.size()) > 1:
1238
+ start_positions = start_positions.squeeze(-1)
1239
+ if len(end_positions.size()) > 1:
1240
+ end_positions = end_positions.squeeze(-1)
1241
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1242
+ ignored_index = start_logits.size(1)
1243
+ start_positions = start_positions.clamp(0, ignored_index)
1244
+ end_positions = end_positions.clamp(0, ignored_index)
1245
+
1246
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1247
+ start_loss = loss_fct(start_logits, start_positions)
1248
+ end_loss = loss_fct(end_logits, end_positions)
1249
+ total_loss = (start_loss + end_loss) / 2
1250
+
1251
+ if not return_dict:
1252
+ output = (start_logits, end_logits) + outputs[1:]
1253
+ return ((total_loss,) + output) if total_loss is not None else output
1254
+
1255
+ return QuestionAnsweringModelOutput(
1256
+ loss=total_loss,
1257
+ start_logits=start_logits,
1258
+ end_logits=end_logits,
1259
+ hidden_states=outputs.hidden_states,
1260
+ attentions=outputs.attentions,
1261
+ )
1262
+
1263
+
1264
+ @auto_docstring
1265
+ class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
1266
+ def __init__(self, config):
1267
+ super().__init__(config)
1268
+
1269
+ num_labels = getattr(config, "num_labels", 2)
1270
+ self.num_labels = num_labels
1271
+
1272
+ self.deberta = DebertaV2Model(config)
1273
+ self.pooler = ContextPooler(config)
1274
+ output_dim = self.pooler.output_dim
1275
+
1276
+ self.classifier = nn.Linear(output_dim, 1)
1277
+ drop_out = getattr(config, "cls_dropout", None)
1278
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1279
+ self.dropout = nn.Dropout(drop_out)
1280
+
1281
+ self.post_init()
1282
+
1283
+ def get_input_embeddings(self):
1284
+ return self.deberta.get_input_embeddings()
1285
+
1286
+ def set_input_embeddings(self, new_embeddings):
1287
+ self.deberta.set_input_embeddings(new_embeddings)
1288
+
1289
+ @auto_docstring
1290
+ def forward(
1291
+ self,
1292
+ input_ids: torch.Tensor | None = None,
1293
+ attention_mask: torch.Tensor | None = None,
1294
+ token_type_ids: torch.Tensor | None = None,
1295
+ position_ids: torch.Tensor | None = None,
1296
+ inputs_embeds: torch.Tensor | None = None,
1297
+ labels: torch.Tensor | None = None,
1298
+ output_attentions: bool | None = None,
1299
+ output_hidden_states: bool | None = None,
1300
+ return_dict: bool | None = None,
1301
+ **kwargs,
1302
+ ) -> tuple | MultipleChoiceModelOutput:
1303
+ r"""
1304
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1305
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1306
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1307
+ `input_ids` above)
1308
+ """
1309
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1310
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1311
+
1312
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1313
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1314
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1315
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1316
+ flat_inputs_embeds = (
1317
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1318
+ if inputs_embeds is not None
1319
+ else None
1320
+ )
1321
+
1322
+ outputs = self.deberta(
1323
+ flat_input_ids,
1324
+ position_ids=flat_position_ids,
1325
+ token_type_ids=flat_token_type_ids,
1326
+ attention_mask=flat_attention_mask,
1327
+ inputs_embeds=flat_inputs_embeds,
1328
+ output_attentions=output_attentions,
1329
+ output_hidden_states=output_hidden_states,
1330
+ return_dict=return_dict,
1331
+ )
1332
+
1333
+ encoder_layer = outputs[0]
1334
+ pooled_output = self.pooler(encoder_layer)
1335
+ pooled_output = self.dropout(pooled_output)
1336
+ logits = self.classifier(pooled_output)
1337
+ reshaped_logits = logits.view(-1, num_choices)
1338
+
1339
+ loss = None
1340
+ if labels is not None:
1341
+ loss_fct = CrossEntropyLoss()
1342
+ loss = loss_fct(reshaped_logits, labels)
1343
+
1344
+ if not return_dict:
1345
+ output = (reshaped_logits,) + outputs[1:]
1346
+ return ((loss,) + output) if loss is not None else output
1347
+
1348
+ return MultipleChoiceModelOutput(
1349
+ loss=loss,
1350
+ logits=reshaped_logits,
1351
+ hidden_states=outputs.hidden_states,
1352
+ attentions=outputs.attentions,
1353
+ )
1354
+
1355
+
1356
+ __all__ = [
1357
+ "DebertaV2ForMaskedLM",
1358
+ "DebertaV2ForMultipleChoice",
1359
+ "DebertaV2ForQuestionAnswering",
1360
+ "DebertaV2ForSequenceClassification",
1361
+ "DebertaV2ForTokenClassification",
1362
+ "DebertaV2Model",
1363
+ "DebertaV2PreTrainedModel",
1364
+ ]
tokenizer.json CHANGED
@@ -1046,33 +1046,6 @@
1046
  "rstrip": false,
1047
  "normalized": true,
1048
  "special": false
1049
- },
1050
- {
1051
- "id": 50368,
1052
- "content": "[START]",
1053
- "single_word": false,
1054
- "lstrip": false,
1055
- "rstrip": false,
1056
- "normalized": false,
1057
- "special": true
1058
- },
1059
- {
1060
- "id": 50369,
1061
- "content": "[END]",
1062
- "single_word": false,
1063
- "lstrip": false,
1064
- "rstrip": false,
1065
- "normalized": false,
1066
- "special": true
1067
- },
1068
- {
1069
- "id": 50370,
1070
- "content": "[GLOSS]",
1071
- "single_word": false,
1072
- "lstrip": false,
1073
- "rstrip": false,
1074
- "normalized": false,
1075
- "special": true
1076
  }
1077
  ],
1078
  "normalizer": {
 
1046
  "rstrip": false,
1047
  "normalized": true,
1048
  "special": false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1049
  }
1050
  ],
1051
  "normalizer": {
tokenizer_config.json CHANGED
@@ -2,11 +2,6 @@
2
  "backend": "tokenizers",
3
  "clean_up_tokenization_spaces": true,
4
  "cls_token": "[CLS]",
5
- "extra_special_tokens": [
6
- "[START]",
7
- "[END]",
8
- "[GLOSS]"
9
- ],
10
  "is_local": false,
11
  "mask_token": "[MASK]",
12
  "model_input_names": [
 
2
  "backend": "tokenizers",
3
  "clean_up_tokenization_spaces": true,
4
  "cls_token": "[CLS]",
 
 
 
 
 
5
  "is_local": false,
6
  "mask_token": "[MASK]",
7
  "model_input_names": [
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7d495bea733a9a1d58c23c5a0a1a14cb54e0ab61bd6b81fc5d20360188143cd1
3
  size 5265
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76ebe973e3dccc2be47ff0763275e5783a298a4f33403e7a15d8ea9e90eeb842
3
  size 5265