chtan committed on
Commit 8ff523b · 1 Parent(s): d1e77d0

Upload PoNetForPreTraining (#1)

- Upload PoNetForPreTraining (ff32cbdcfa90b14ca85c98f2134ccfb29b6ccf40)

Files changed (4)
  1. config.json +9 -3
  2. configuration_ponet.py +149 -0
  3. modeling_ponet.py +1000 -0
  4. pytorch_model.bin +2 -2
config.json CHANGED
@@ -1,8 +1,14 @@
 {
-  "_name_or_path": "ponet-base-uncased",
+  "_name_or_path": "ponet",
   "architectures": [
     "PoNetForPreTraining"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_ponet.PoNetConfig",
+    "AutoModelForPreTraining": "modeling_ponet.PoNetForPreTraining"
+  },
+  "classifier_dropout": null,
+  "clsgsepg": true,
   "gradient_checkpointing": false,
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
@@ -16,9 +22,9 @@
   "num_hidden_layers": 12,
   "pad_token_id": 0,
   "position_embedding_type": "absolute",
-  "transformers_version": "4.7.0",
+  "torch_dtype": "float32",
+  "transformers_version": "4.28.0.dev0",
   "type_vocab_size": 2,
   "use_cache": true,
-  "clsgsepg": true,
   "vocab_size": 30522
 }
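
The `auto_map` entries added above are what let this checkpoint be loaded through the `Auto*` classes even though PoNet is not part of the `transformers` library itself. A minimal loading sketch, assuming the target repo is `chtan/ponet-base-uncased` (the id referenced in the uploaded code) and that tokenizer files are available in it:

```python
import torch
from transformers import AutoConfig, AutoModelForPreTraining, AutoTokenizer

repo = "chtan/ponet-base-uncased"  # assumed repo id, taken from the uploaded files

# trust_remote_code=True tells transformers to execute configuration_ponet.py and
# modeling_ponet.py from the repo, following the auto_map added in config.json.
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForPreTraining.from_pretrained(repo, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.prediction_logits.shape)        # (batch, seq_len, vocab_size)
print(outputs.seq_relationship_logits.shape)  # (batch, 3) — sentence structural objective
```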
configuration_ponet.py ADDED
@@ -0,0 +1,149 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ PONET model configuration"""
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.onnx import OnnxConfig
22
+ from transformers.utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ PONET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
28
+ "chtan/ponet-base-uncased": "https://huggingface.co/chtan/ponet-base-uncased/resolve/main/config.json",
29
+ }
30
+
31
+
32
+ class PoNetConfig(PretrainedConfig):
33
+ r"""
34
+ This is the configuration class to store the configuration of a [`PoNetModel`] or a [`TFPoNetModel`]. It is used to
35
+ instantiate a PONET model according to the specified arguments, defining the model architecture. Instantiating a
36
+ configuration with the defaults will yield a similar configuration to that of the PONET
37
+ [chtan/ponet-base-uncased](https://huggingface.co/chtan/ponet-base-uncased) architecture.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+
43
+ Args:
44
+ vocab_size (`int`, *optional*, defaults to 30522):
45
+ Vocabulary size of the PONET model. Defines the number of different tokens that can be represented by the
46
+ `inputs_ids` passed when calling [`PoNetModel`] or [`TFPoNetModel`].
47
+ hidden_size (`int`, *optional*, defaults to 768):
48
+ Dimensionality of the encoder layers and the pooler layer.
49
+ num_hidden_layers (`int`, *optional*, defaults to 12):
50
+ Number of hidden layers in the Transformer encoder.
51
+ num_attention_heads (`int`, *optional*, defaults to 12):
52
+ Number of attention heads for each attention layer in the Transformer encoder.
53
+ intermediate_size (`int`, *optional*, defaults to 3072):
54
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
55
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
56
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
57
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
58
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
59
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
60
+ max_position_embeddings (`int`, *optional*, defaults to 512):
61
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
62
+ just in case (e.g., 512 or 1024 or 2048).
63
+ type_vocab_size (`int`, *optional*, defaults to 2):
64
+ The vocabulary size of the `token_type_ids` passed when calling [`PoNetModel`] or [`TFPoNetModel`].
65
+ initializer_range (`float`, *optional*, defaults to 0.02):
66
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
68
+ The epsilon used by the layer normalization layers.
69
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
70
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
71
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
72
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
73
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
74
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
75
+ is_decoder (`bool`, *optional*, defaults to `False`):
76
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
77
+ use_cache (`bool`, *optional*, defaults to `True`):
78
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
79
+ relevant if `config.is_decoder=True`.
80
+ classifier_dropout (`float`, *optional*):
81
+ The dropout ratio for the classification head.
82
+
83
+ Examples:
84
+
85
+ ```python
86
+ >>> from transformers import PoNetConfig, PoNetModel
87
+
88
+ >>> # Initializing a PONET chtan/ponet-base-uncased style configuration
89
+ >>> configuration = PoNetConfig()
90
+
91
+ >>> # Initializing a model (with random weights) from the chtan/ponet-base-uncased style configuration
92
+ >>> model = PoNetModel(configuration)
93
+
94
+ >>> # Accessing the model configuration
95
+ >>> configuration = model.config
96
+ ```"""
97
+ model_type = "ponet"
98
+
99
+ def __init__(
100
+ self,
101
+ vocab_size=30522,
102
+ hidden_size=768,
103
+ num_hidden_layers=12,
104
+ num_attention_heads=12,
105
+ intermediate_size=3072,
106
+ hidden_act="gelu",
107
+ hidden_dropout_prob=0.1,
108
+ max_position_embeddings=512,
109
+ type_vocab_size=2,
110
+ initializer_range=0.02,
111
+ layer_norm_eps=1e-12,
112
+ pad_token_id=0,
113
+ position_embedding_type="absolute",
114
+ use_cache=True,
115
+ classifier_dropout=None,
116
+ **kwargs,
117
+ ):
118
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
119
+
120
+ self.vocab_size = vocab_size
121
+ self.hidden_size = hidden_size
122
+ self.num_hidden_layers = num_hidden_layers
123
+ self.num_attention_heads = num_attention_heads
124
+ self.hidden_act = hidden_act
125
+ self.intermediate_size = intermediate_size
126
+ self.hidden_dropout_prob = hidden_dropout_prob
127
+ self.max_position_embeddings = max_position_embeddings
128
+ self.type_vocab_size = type_vocab_size
129
+ self.initializer_range = initializer_range
130
+ self.layer_norm_eps = layer_norm_eps
131
+ self.position_embedding_type = position_embedding_type
132
+ self.use_cache = use_cache
133
+ self.classifier_dropout = classifier_dropout
134
+
135
+
136
+ class PoNetOnnxConfig(OnnxConfig):
137
+ @property
138
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
139
+ if self.task == "multiple-choice":
140
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
141
+ else:
142
+ dynamic_axis = {0: "batch", 1: "sequence"}
143
+ return OrderedDict(
144
+ [
145
+ ("input_ids", dynamic_axis),
146
+ ("attention_mask", dynamic_axis),
147
+ ("token_type_ids", dynamic_axis),
148
+ ]
149
+ )
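
A short, self-contained sketch of the new configuration module (assuming `configuration_ponet.py` from this commit is importable locally and that the `transformers.onnx` API matches the pinned transformers version):

```python
from configuration_ponet import PoNetConfig, PoNetOnnxConfig

# Defaults mirror the base model hyperparameters in config.json above.
config = PoNetConfig()
print(config.model_type, config.hidden_size, config.num_hidden_layers)  # ponet 768 12

# The ONNX helper only declares which input axes are dynamic.
onnx_config = PoNetOnnxConfig(config)
print(onnx_config.inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
#              ('attention_mask', {0: 'batch', 1: 'sequence'}),
#              ('token_type_ids', {0: 'batch', 1: 'sequence'})])
```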
modeling_ponet.py ADDED
@@ -0,0 +1,1000 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch PONET model."""
17
+
18
+
19
+ import math
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutput,
31
+ BaseModelOutputWithPooling,
32
+ SequenceClassifierOutput,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
36
+ from transformers.utils import (
37
+ ModelOutput,
38
+ add_code_sample_docstrings,
39
+ add_start_docstrings,
40
+ add_start_docstrings_to_model_forward,
41
+ logging,
42
+ replace_return_docstrings,
43
+ )
44
+ from .configuration_ponet import PoNetConfig
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+ _CHECKPOINT_FOR_DOC = "chtan/ponet-base-uncased"
50
+ _CONFIG_FOR_DOC = "PoNetConfig"
51
+
52
+
53
+ PONET_PRETRAINED_MODEL_ARCHIVE_LIST = [
54
+ "chtan/ponet-base-uncased",
55
+ # See all PoNet models at https://huggingface.co/models?filter=ponet
56
+ ]
57
+
58
+ # XXX: get from tokenizer
59
+ CLS_ID = 101
60
+ EOS_ID = 102
61
+
62
+
63
+ def segment_max(src, index, dim=1):
64
+ out = torch.zeros_like(src).scatter_reduce(
65
+ dim, index.unsqueeze(-1).expand_as(src), src, reduce="amax", include_self=False
66
+ )
67
+ dummy = index.unsqueeze(-1).expand(*index.shape[:2], out.size(-1))
68
+ return torch.gather(out, dim, dummy).to(dtype=src.dtype)
69
+
70
+
71
+ def get_segment_index(input_ids, cls_id=CLS_ID, eos_id=EOS_ID):
72
+ mask = (input_ids == cls_id).to(dtype=torch.long) + (input_ids == eos_id).to(dtype=torch.long)
73
+ mask = mask + torch.cat([torch.zeros_like(mask[:, 0:1]), mask[:, :-1]], dim=1)
74
+ num_segments = input_ids[:, :1] == cls_id
75
+ segment_idx = mask.cumsum(dim=1)
76
+ return torch.where(num_segments == 0, segment_idx, segment_idx - 1)
77
+
78
+
79
+ def get_token_type_mask(input_ids, cls_id=CLS_ID, eos_id=EOS_ID):
80
+ mask = (input_ids == cls_id) | (input_ids == eos_id)
81
+ return mask
82
+
83
+
84
+ def get_win_max(hidden_states, kernel_size=3):
85
+ m = nn.MaxPool1d(kernel_size, stride=1, padding=kernel_size // 2)
86
+ out = m(hidden_states.permute(0, 2, 1)).permute(0, 2, 1)
87
+ return out
88
+
89
+
90
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->PoNet
91
+ class PoNetEmbeddings(nn.Module):
92
+ """Construct the embeddings from word, position and token_type embeddings."""
93
+
94
+ def __init__(self, config):
95
+ super().__init__()
96
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
97
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
98
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
99
+
100
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
101
+ # any TensorFlow checkpoint file
102
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
103
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
104
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
105
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
106
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
107
+ self.register_buffer(
108
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
109
+ )
110
+
111
+ def forward(
112
+ self,
113
+ input_ids: Optional[torch.LongTensor] = None,
114
+ token_type_ids: Optional[torch.LongTensor] = None,
115
+ position_ids: Optional[torch.LongTensor] = None,
116
+ inputs_embeds: Optional[torch.FloatTensor] = None,
117
+ past_key_values_length: int = 0,
118
+ ) -> torch.Tensor:
119
+ if input_ids is not None:
120
+ input_shape = input_ids.size()
121
+ else:
122
+ input_shape = inputs_embeds.size()[:-1]
123
+
124
+ seq_length = input_shape[1]
125
+
126
+ if position_ids is None:
127
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
128
+
129
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
130
+ # when it's auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
131
+ # issue #5664
132
+ if token_type_ids is None:
133
+ if hasattr(self, "token_type_ids"):
134
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
135
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
136
+ token_type_ids = buffered_token_type_ids_expanded
137
+ else:
138
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
139
+
140
+ if inputs_embeds is None:
141
+ inputs_embeds = self.word_embeddings(input_ids)
142
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
143
+
144
+ embeddings = inputs_embeds + token_type_embeddings
145
+ if self.position_embedding_type == "absolute":
146
+ position_embeddings = self.position_embeddings(position_ids)
147
+ embeddings += position_embeddings
148
+ embeddings = self.LayerNorm(embeddings)
149
+ embeddings = self.dropout(embeddings)
150
+ return embeddings
151
+
152
+
153
+ class PoNetSelfAttention(nn.Module):
154
+ def __init__(self, config):
155
+ super().__init__()
156
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
157
+ raise ValueError(
158
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
159
+ f"heads ({config.num_attention_heads})"
160
+ )
161
+ self.clsgsepg = getattr(config, "clsgsepg", True)
162
+
163
+ self.num_attention_heads = config.num_attention_heads
164
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
165
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
166
+
167
+ self.dense_local = nn.Linear(config.hidden_size, config.hidden_size)
168
+ self.dense_segment = nn.Linear(config.hidden_size, config.hidden_size)
169
+
170
+ self.dense_q = nn.Linear(config.hidden_size, self.all_head_size)
171
+ self.dense_k = nn.Linear(config.hidden_size, self.all_head_size)
172
+ self.dense_o = nn.Linear(config.hidden_size, self.all_head_size)
173
+
174
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
175
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
176
+ x = x.view(new_x_shape)
177
+ return x.permute(0, 2, 1, 3)
178
+
179
+ def forward(
180
+ self,
181
+ hidden_states: torch.Tensor,
182
+ segment_index: torch.LongTensor,
183
+ token_type_mask: torch.LongTensor,
184
+ attention_mask: Optional[torch.FloatTensor] = None,
185
+ output_attentions: Optional[bool] = False,
186
+ ) -> Tuple[torch.Tensor]:
187
+ context_layer_q = self.transpose_for_scores(self.dense_q(hidden_states))
188
+ context_layer_k = self.transpose_for_scores(self.dense_k(hidden_states))
189
+ context_layer_v = context_layer_k
190
+ context_layer_o = self.transpose_for_scores(self.dense_o(hidden_states))
191
+
192
+ if attention_mask is not None:
193
+ _attention_mask = attention_mask.squeeze(1).unsqueeze(-1) < -1
194
+
195
+ if attention_mask is not None:
196
+ context_layer_q.masked_fill_(_attention_mask, 0.0)
197
+ q = context_layer_q.sum(dim=-2) / torch.ones_like(_attention_mask).to(
198
+ dtype=context_layer_q.dtype
199
+ ).masked_fill(_attention_mask, 0.0).sum(dim=-2)
200
+ else:
201
+ q = context_layer_q.mean(dim=-2)
202
+ att = torch.einsum("bdh,bdlh -> bdl", q, context_layer_k) / math.sqrt(context_layer_q.shape[-1])
203
+ if attention_mask is not None:
204
+ att = att + attention_mask.squeeze(1)
205
+ att_prob = att.softmax(dim=-1)
206
+ v = torch.einsum("bdlh,bdl->bdh", context_layer_v, att_prob)
207
+
208
+ context_layer_segment = self.dense_segment(hidden_states)
209
+ context_layer_local = self.dense_local(hidden_states)
210
+ if attention_mask is not None:
211
+ context_layer_local.masked_fill_(_attention_mask.squeeze(1), -10000)
212
+ context_layer_segment.masked_fill_(_attention_mask.squeeze(1), -10000)
213
+
214
+ if self.clsgsepg:
215
+ # XXX: a trick to make sure the segment and local information will not leak
216
+ context_layer_local = get_win_max(
217
+ context_layer_local.masked_fill(token_type_mask.unsqueeze(dim=-1), -10000)
218
+ )
219
+ context_layer_segment = segment_max(context_layer_segment, index=segment_index)
220
+
221
+ context_layer_segment.masked_fill_(token_type_mask.unsqueeze(dim=-1), 0.0)
222
+ context_layer_local.masked_fill_(token_type_mask.unsqueeze(dim=-1), 0.0)
223
+ else:
224
+ context_layer_local = get_win_max(context_layer_local)
225
+ context_layer_segment = segment_max(context_layer_segment, index=segment_index)
226
+
227
+ context_layer_local = self.transpose_for_scores(context_layer_local)
228
+ context_layer_segment = self.transpose_for_scores(context_layer_segment)
229
+
230
+ context_layer = (v.unsqueeze(dim=-2) + context_layer_segment) * context_layer_o + context_layer_local
231
+ context_layer = context_layer.permute(0, 2, 1, 3).reshape(*hidden_states.shape[:2], -1)
232
+
233
+ if attention_mask is not None:
234
+ context_layer.masked_fill_(_attention_mask.squeeze(1), 0.0)
235
+
236
+ outputs = (context_layer, att_prob) if output_attentions else (context_layer,)
237
+ return outputs
238
+
239
+
240
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->PoNet
241
+ class PoNetSelfOutput(nn.Module):
242
+ def __init__(self, config):
243
+ super().__init__()
244
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
245
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
246
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
247
+
248
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
249
+ hidden_states = self.dense(hidden_states)
250
+ hidden_states = self.dropout(hidden_states)
251
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
252
+ return hidden_states
253
+
254
+
255
+ class PoNetAttention(nn.Module):
256
+ def __init__(self, config):
257
+ super().__init__()
258
+ self.self = PoNetSelfAttention(config)
259
+ self.output = PoNetSelfOutput(config)
260
+ self.pruned_heads = set()
261
+
262
+ def prune_heads(self, heads):
263
+ if len(heads) == 0:
264
+ return
265
+ heads, index = find_pruneable_heads_and_indices(
266
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
267
+ )
268
+
269
+ # Prune linear layers
270
+ self.self.query = prune_linear_layer(self.self.query, index)
271
+ self.self.key = prune_linear_layer(self.self.key, index)
272
+ self.self.value = prune_linear_layer(self.self.value, index)
273
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
274
+
275
+ # Update hyper params and store pruned heads
276
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
277
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
278
+ self.pruned_heads = self.pruned_heads.union(heads)
279
+
280
+ def forward(
281
+ self,
282
+ hidden_states: torch.Tensor,
283
+ segment_index: torch.LongTensor,
284
+ token_type_mask: torch.LongTensor,
285
+ attention_mask: Optional[torch.FloatTensor] = None,
286
+ output_attentions: Optional[bool] = False,
287
+ ) -> Tuple[torch.Tensor]:
288
+ self_outputs = self.self(
289
+ hidden_states,
290
+ segment_index,
291
+ token_type_mask,
292
+ attention_mask,
293
+ output_attentions,
294
+ )
295
+ attention_output = self.output(self_outputs[0], hidden_states)
296
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
297
+ return outputs
298
+
299
+
300
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->PoNet
301
+ class PoNetIntermediate(nn.Module):
302
+ def __init__(self, config):
303
+ super().__init__()
304
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
305
+ if isinstance(config.hidden_act, str):
306
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
307
+ else:
308
+ self.intermediate_act_fn = config.hidden_act
309
+
310
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
311
+ hidden_states = self.dense(hidden_states)
312
+ hidden_states = self.intermediate_act_fn(hidden_states)
313
+ return hidden_states
314
+
315
+
316
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->PoNet
317
+ class PoNetOutput(nn.Module):
318
+ def __init__(self, config):
319
+ super().__init__()
320
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
321
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
322
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
323
+
324
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
325
+ hidden_states = self.dense(hidden_states)
326
+ hidden_states = self.dropout(hidden_states)
327
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
328
+ return hidden_states
329
+
330
+
331
+ class PoNetLayer(nn.Module):
332
+ def __init__(self, config):
333
+ super().__init__()
334
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
335
+ self.seq_len_dim = 1
336
+ self.attention = PoNetAttention(config)
337
+
338
+ config.is_decoder = False # XXX: Decoder is not yet implemented.
339
+ self.is_decoder = config.is_decoder
340
+
341
+ self.intermediate = PoNetIntermediate(config)
342
+ self.output = PoNetOutput(config)
343
+
344
+ def forward(
345
+ self,
346
+ hidden_states: torch.Tensor,
347
+ segment_index: torch.LongTensor,
348
+ token_type_mask: torch.LongTensor,
349
+ attention_mask: Optional[torch.FloatTensor] = None,
350
+ output_attentions: Optional[bool] = False,
351
+ ) -> Tuple[torch.Tensor]:
352
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
353
+ self_attention_outputs = self.attention(
354
+ hidden_states,
355
+ segment_index,
356
+ token_type_mask,
357
+ attention_mask,
358
+ output_attentions=output_attentions,
359
+ )
360
+ attention_output = self_attention_outputs[0]
361
+
362
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
363
+
364
+ layer_output = apply_chunking_to_forward(
365
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
366
+ )
367
+ outputs = (layer_output,) + outputs
368
+
369
+ return outputs
370
+
371
+ def feed_forward_chunk(self, attention_output):
372
+ intermediate_output = self.intermediate(attention_output)
373
+ layer_output = self.output(intermediate_output, attention_output)
374
+ return layer_output
375
+
376
+
377
+ class PoNetEncoder(nn.Module):
378
+ def __init__(self, config):
379
+ super().__init__()
380
+ self.config = config
381
+ self.layer = nn.ModuleList([PoNetLayer(config) for _ in range(config.num_hidden_layers)])
382
+ self.gradient_checkpointing = False
383
+
384
+ def forward(
385
+ self,
386
+ hidden_states: torch.Tensor,
387
+ segment_index: torch.LongTensor,
388
+ token_type_mask: torch.LongTensor,
389
+ attention_mask: Optional[torch.FloatTensor] = None,
390
+ output_attentions: Optional[bool] = False,
391
+ output_hidden_states: Optional[bool] = False,
392
+ return_dict: Optional[bool] = True,
393
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
394
+ all_hidden_states = () if output_hidden_states else None
395
+ all_self_attentions = () if output_attentions else None
396
+
397
+ for i, layer_module in enumerate(self.layer):
398
+ if output_hidden_states:
399
+ all_hidden_states = all_hidden_states + (hidden_states,)
400
+
401
+ if self.gradient_checkpointing and self.training:
402
+
403
+ def create_custom_forward(module):
404
+ def custom_forward(*inputs):
405
+ return module(*inputs, output_attentions)
406
+
407
+ return custom_forward
408
+
409
+ layer_outputs = torch.utils.checkpoint.checkpoint(
410
+ create_custom_forward(layer_module),
411
+ hidden_states,
412
+ segment_index,
413
+ token_type_mask,
414
+ attention_mask,
415
+ )
416
+ else:
417
+ layer_outputs = layer_module(
418
+ hidden_states,
419
+ segment_index,
420
+ token_type_mask,
421
+ attention_mask,
422
+ output_attentions,
423
+ )
424
+
425
+ hidden_states = layer_outputs[0]
426
+ if output_attentions:
427
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
428
+
429
+ if output_hidden_states:
430
+ all_hidden_states = all_hidden_states + (hidden_states,)
431
+
432
+ if not return_dict:
433
+ return tuple(
434
+ v
435
+ for v in [
436
+ hidden_states,
437
+ all_hidden_states,
438
+ all_self_attentions,
439
+ ]
440
+ if v is not None
441
+ )
442
+ return BaseModelOutput(
443
+ last_hidden_state=hidden_states,
444
+ hidden_states=all_hidden_states,
445
+ attentions=all_self_attentions,
446
+ )
447
+
448
+
449
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->PoNet
450
+ class PoNetPooler(nn.Module):
451
+ def __init__(self, config):
452
+ super().__init__()
453
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
454
+ self.activation = nn.Tanh()
455
+
456
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
457
+ # We "pool" the model by simply taking the hidden state corresponding
458
+ # to the first token.
459
+ first_token_tensor = hidden_states[:, 0]
460
+ pooled_output = self.dense(first_token_tensor)
461
+ pooled_output = self.activation(pooled_output)
462
+ return pooled_output
463
+
464
+
465
+ # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->PoNet
466
+ class PoNetPredictionHeadTransform(nn.Module):
467
+ def __init__(self, config):
468
+ super().__init__()
469
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
470
+ if isinstance(config.hidden_act, str):
471
+ self.transform_act_fn = ACT2FN[config.hidden_act]
472
+ else:
473
+ self.transform_act_fn = config.hidden_act
474
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
475
+
476
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
477
+ hidden_states = self.dense(hidden_states)
478
+ hidden_states = self.transform_act_fn(hidden_states)
479
+ hidden_states = self.LayerNorm(hidden_states)
480
+ return hidden_states
481
+
482
+
483
+ # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->PoNet
484
+ class PoNetLMPredictionHead(nn.Module):
485
+ def __init__(self, config):
486
+ super().__init__()
487
+ self.transform = PoNetPredictionHeadTransform(config)
488
+
489
+ # The output weights are the same as the input embeddings, but there is
490
+ # an output-only bias for each token.
491
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
492
+
493
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
494
+
495
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
496
+ self.decoder.bias = self.bias
497
+
498
+ def forward(self, hidden_states):
499
+ hidden_states = self.transform(hidden_states)
500
+ hidden_states = self.decoder(hidden_states)
501
+ return hidden_states
502
+
503
+
504
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->PoNet
505
+ class PoNetOnlyMLMHead(nn.Module):
506
+ def __init__(self, config):
507
+ super().__init__()
508
+ self.predictions = PoNetLMPredictionHead(config)
509
+
510
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
511
+ prediction_scores = self.predictions(sequence_output)
512
+ return prediction_scores
513
+
514
+
515
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->PoNet
516
+ class PoNetOnlyNSPHead(nn.Module):
517
+ def __init__(self, config):
518
+ super().__init__()
519
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
520
+
521
+ def forward(self, pooled_output):
522
+ seq_relationship_score = self.seq_relationship(pooled_output)
523
+ return seq_relationship_score
524
+
525
+
526
+ class PoNetOnlySSOHead(nn.Module):
527
+ def __init__(self, config):
528
+ super().__init__()
529
+ self.seq_relationship = nn.Linear(config.hidden_size, 3)
530
+
531
+ def forward(self, pooled_output):
532
+ seq_relationship_score = self.seq_relationship(pooled_output)
533
+ return seq_relationship_score
534
+
535
+
536
+ class PoNetPreTrainingHeads(nn.Module):
537
+ def __init__(self, config):
538
+ super().__init__()
539
+ self.predictions = PoNetLMPredictionHead(config)
540
+ self.seq_relationship = nn.Linear(config.hidden_size, 3) # 3 classes: sentence structural objective (SSO)
541
+
542
+ def forward(self, sequence_output, pooled_output):
543
+ prediction_scores = self.predictions(sequence_output)
544
+ seq_relationship_score = self.seq_relationship(pooled_output)
545
+ return prediction_scores, seq_relationship_score
546
+
547
+
548
+ class PoNetPreTrainedModel(PreTrainedModel):
549
+ """
550
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
551
+ models.
552
+ """
553
+
554
+ config_class = PoNetConfig
555
+ base_model_prefix = "ponet"
556
+ supports_gradient_checkpointing = True
557
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
558
+
559
+ def _init_weights(self, module):
560
+ """Initialize the weights"""
561
+ if isinstance(module, nn.Linear):
562
+ # Slightly different from the TF version which uses truncated_normal for initialization
563
+ # cf https://github.com/pytorch/pytorch/pull/5617
564
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
565
+ if module.bias is not None:
566
+ module.bias.data.zero_()
567
+ elif isinstance(module, nn.Embedding):
568
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
569
+ if module.padding_idx is not None:
570
+ module.weight.data[module.padding_idx].zero_()
571
+ elif isinstance(module, nn.LayerNorm):
572
+ module.bias.data.zero_()
573
+ module.weight.data.fill_(1.0)
574
+
575
+ def _set_gradient_checkpointing(self, module, value=False):
576
+ if isinstance(module, PoNetEncoder):
577
+ module.gradient_checkpointing = value
578
+
579
+
580
+ @dataclass
581
+ class PoNetForPreTrainingOutput(ModelOutput):
582
+ """
583
+ Output type of [*PoNetForPreTraining*].
584
+
585
+ Args:
586
+ loss (*optional*, returned when *labels* is provided, *torch.FloatTensor* of shape *(1,)*):
587
+ Total loss as the sum of the masked language modeling loss and the sentence structural objective
588
+ (classification) loss.
589
+ mlm_loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
590
+ Masked language modeling loss.
591
+ sso_loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
592
+ Sentence structural objective (SSO) loss.
593
+ prediction_logits (*torch.FloatTensor* of shape *(batch_size, sequence_length, config.vocab_size)*):
594
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
595
+ seq_relationship_logits (*torch.FloatTensor* of shape *(batch_size, 3)*):
596
+ Prediction scores of the sentence structural objective (classification) head (scores for the three classes
597
+ before SoftMax).
598
+ hidden_states (*tuple(torch.FloatTensor)*, *optional*, returned when *output_hidden_states=True* is passed or when *config.output_hidden_states=True*):
599
+ Tuple of *torch.FloatTensor* (one for the output of the embeddings + one for the output of each layer) of
600
+ shape *(batch_size, sequence_length, hidden_size)*.
601
+
602
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
603
+ attentions (*tuple(torch.FloatTensor)*, *optional*, returned when *output_attentions=True* is passed or when *config.output_attentions=True*):
604
+ Tuple of *torch.FloatTensor* (one for each layer) of shape *(batch_size, num_heads, sequence_length,
605
+ sequence_length)*.
606
+
607
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
608
+ heads.
609
+ """
610
+
611
+ loss: Optional[torch.FloatTensor] = None
612
+ mlm_loss: Optional[torch.FloatTensor] = None
613
+ sso_loss: Optional[torch.FloatTensor] = None
614
+ prediction_logits: torch.FloatTensor = None
615
+ seq_relationship_logits: torch.FloatTensor = None
616
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
617
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
618
+
619
+
620
+ PONET_START_DOCSTRING = r"""
621
+
622
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
623
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
624
+ etc.)
625
+
626
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
627
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
628
+ and behavior.
629
+
630
+ Parameters:
631
+ config ([`PoNetConfig`]): Model configuration class with all the parameters of the model.
632
+ Initializing with a config file does not load the weights associated with the model, only the
633
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
634
+ """
635
+
636
+ PONET_INPUTS_DOCSTRING = r"""
637
+ Args:
638
+ input_ids (`torch.LongTensor` of shape `({0})`):
639
+ Indices of input sequence tokens in the vocabulary.
640
+
641
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
642
+ [`PreTrainedTokenizer.__call__`] for details.
643
+
644
+ [What are input IDs?](../glossary#input-ids)
645
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
646
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
647
+
648
+ - 1 for tokens that are **not masked**,
649
+ - 0 for tokens that are **masked**.
650
+
651
+ [What are attention masks?](../glossary#attention-mask)
652
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
653
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
654
+ 1]`:
655
+
656
+ - 0 corresponds to a *sentence A* token,
657
+ - 1 corresponds to a *sentence B* token.
658
+
659
+ [What are token type IDs?](../glossary#token-type-ids)
660
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
661
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
662
+ config.max_position_embeddings - 1]`.
663
+
664
+ [What are position IDs?](../glossary#position-ids)
665
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
666
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
667
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
668
+ model's internal embedding lookup matrix.
669
+ output_attentions (`bool`, *optional*):
670
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
671
+ tensors for more detail.
672
+ output_hidden_states (`bool`, *optional*):
673
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
674
+ more detail.
675
+ return_dict (`bool`, *optional*):
676
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
677
+ """
678
+
679
+
680
+ @add_start_docstrings(
681
+ "The bare PoNet Model transformer outputting raw hidden-states without any specific head on top.",
682
+ PONET_START_DOCSTRING,
683
+ )
684
+ class PoNetModel(PoNetPreTrainedModel):
685
+ def __init__(self, config, add_pooling_layer=True):
686
+ super().__init__(config)
687
+ self.config = config
688
+
689
+ self.embeddings = PoNetEmbeddings(config)
690
+ self.encoder = PoNetEncoder(config)
691
+
692
+ self.pooler = PoNetPooler(config) if add_pooling_layer else None
693
+
694
+ # Initialize weights and apply final processing
695
+ self.post_init()
696
+
697
+ def get_input_embeddings(self):
698
+ return self.embeddings.word_embeddings
699
+
700
+ def set_input_embeddings(self, value):
701
+ self.embeddings.word_embeddings = value
702
+
703
+ def _prune_heads(self, heads_to_prune):
704
+ """
705
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
706
+ class PreTrainedModel
707
+ """
708
+ for layer, heads in heads_to_prune.items():
709
+ self.encoder.layer[layer].attention.prune_heads(heads)
710
+
711
+ @add_start_docstrings_to_model_forward(PONET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
712
+ @add_code_sample_docstrings(
713
+ checkpoint=_CHECKPOINT_FOR_DOC,
714
+ output_type=BaseModelOutputWithPooling,
715
+ config_class=_CONFIG_FOR_DOC,
716
+ )
717
+ def forward(
718
+ self,
719
+ input_ids: Optional[torch.Tensor] = None,
720
+ attention_mask: Optional[torch.Tensor] = None,
721
+ token_type_ids: Optional[torch.Tensor] = None,
722
+ segment_ids: Optional[torch.Tensor] = None,
723
+ position_ids: Optional[torch.Tensor] = None,
724
+ inputs_embeds: Optional[torch.Tensor] = None,
725
+ output_attentions: Optional[bool] = None,
726
+ output_hidden_states: Optional[bool] = None,
727
+ return_dict: Optional[bool] = None,
728
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
729
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
730
+ output_hidden_states = (
731
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
732
+ )
733
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
734
+
735
+ if input_ids is not None and inputs_embeds is not None:
736
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
737
+ elif input_ids is not None:
738
+ input_shape = input_ids.size()
739
+ elif inputs_embeds is not None:
740
+ input_shape = inputs_embeds.size()[:-1]
741
+ else:
742
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
743
+
744
+ batch_size, seq_length = input_shape
745
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
746
+
747
+ if attention_mask is None:
748
+ attention_mask = torch.ones(((batch_size, seq_length)), device=device)
749
+
750
+ if token_type_ids is None:
751
+ if hasattr(self.embeddings, "token_type_ids"):
752
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
753
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
754
+ token_type_ids = buffered_token_type_ids_expanded
755
+ else:
756
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
757
+
758
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
759
+ # ourselves in which case we just need to make it broadcastable to all heads.
760
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
761
+
762
+ embedding_output = self.embeddings(
763
+ input_ids=input_ids,
764
+ position_ids=position_ids,
765
+ token_type_ids=token_type_ids,
766
+ inputs_embeds=inputs_embeds,
767
+ )
768
+
769
+ segment_index = get_segment_index(input_ids) if segment_ids is None else segment_ids
770
+ token_type_mask = get_token_type_mask(input_ids)
771
+ encoder_outputs = self.encoder(
772
+ embedding_output,
773
+ segment_index,
774
+ token_type_mask,
775
+ attention_mask=extended_attention_mask,
776
+ output_attentions=output_attentions,
777
+ output_hidden_states=output_hidden_states,
778
+ return_dict=return_dict,
779
+ )
780
+ sequence_output = encoder_outputs[0]
781
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
782
+
783
+ if not return_dict:
784
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
785
+
786
+ return BaseModelOutputWithPooling(
787
+ last_hidden_state=sequence_output,
788
+ pooler_output=pooled_output,
789
+ hidden_states=encoder_outputs.hidden_states,
790
+ attentions=encoder_outputs.attentions,
791
+ )
792
+
793
+
794
+ @add_start_docstrings(
795
+ """
796
+ PoNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `sentence
797
+ structural objective (classification)` head.
798
+ """,
799
+ PONET_START_DOCSTRING,
800
+ )
801
+ class PoNetForPreTraining(PoNetPreTrainedModel):
802
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"cls.predictions.decoder.bias"]
803
+
804
+ def __init__(self, config):
805
+ super().__init__(config)
806
+
807
+ self.ponet = PoNetModel(config)
808
+ self.cls = PoNetPreTrainingHeads(config)
809
+
810
+ # Initialize weights and apply final processing
811
+ self.post_init()
812
+
813
+ @add_start_docstrings_to_model_forward(PONET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
814
+ @replace_return_docstrings(output_type=PoNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
815
+ def forward(
816
+ self,
817
+ input_ids: Optional[torch.Tensor] = None,
818
+ attention_mask: Optional[torch.Tensor] = None,
819
+ token_type_ids: Optional[torch.Tensor] = None,
820
+ segment_ids: Optional[torch.Tensor] = None,
821
+ position_ids: Optional[torch.Tensor] = None,
822
+ inputs_embeds: Optional[torch.Tensor] = None,
823
+ labels: Optional[torch.Tensor] = None,
824
+ sentence_structural_label: Optional[torch.Tensor] = None,
825
+ output_attentions: Optional[bool] = None,
826
+ output_hidden_states: Optional[bool] = None,
827
+ return_dict: Optional[bool] = None,
828
+ ) -> Union[Tuple[torch.Tensor], PoNetForPreTrainingOutput]:
829
+ r"""
830
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
831
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
832
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
833
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
834
+ sentence_structural_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
835
+ Labels for computing the sentence structural objective (classification) loss. Input should be a
836
+ sequence pair (see `input_ids` docstring) Indices should be in `[0, 1, 2]`:
837
+
838
+ - 0 indicates sequence B is a continuation of sequence A,
839
+ - 1 indicates sequence A is a continuation of sequence B,
840
+ - 2 indicates sequence B is a random sequence.
841
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
842
+ Used to hide legacy arguments that have been deprecated.
843
+
844
+ Returns:
845
+
846
+ Example:
847
+
848
+ ```python
849
+ >>> from transformers import AutoTokenizer, AutoModelForPreTraining
850
+ >>> import torch
851
+
852
+ >>> tokenizer = AutoTokenizer.from_pretrained("chtan/ponet-base-uncased")
853
+ >>> model = AutoModelForPreTraining.from_pretrained("chtan/ponet-base-uncased", trust_remote_code=True)
854
+
855
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
856
+ >>> outputs = model(**inputs)
857
+
858
+ >>> prediction_logits = outputs.prediction_logits
859
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
860
+ ```
861
+ """
862
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
863
+
864
+ outputs = self.ponet(
865
+ input_ids,
866
+ attention_mask=attention_mask,
867
+ token_type_ids=token_type_ids,
868
+ segment_ids=segment_ids,
869
+ position_ids=position_ids,
870
+ inputs_embeds=inputs_embeds,
871
+ output_attentions=output_attentions,
872
+ output_hidden_states=output_hidden_states,
873
+ return_dict=return_dict,
874
+ )
875
+
876
+ sequence_output, pooled_output = outputs[:2]
877
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
878
+
879
+ total_loss = None
880
+ masked_lm_loss = None
881
+ sso_loss = None
882
+ if labels is not None and sentence_structural_label is not None:
883
+ loss_fct = CrossEntropyLoss()
884
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
885
+ sso_loss = loss_fct(seq_relationship_score.view(-1, 3), sentence_structural_label.view(-1))
886
+ total_loss = masked_lm_loss + sso_loss
887
+
888
+ if not return_dict:
889
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
890
+ return ((total_loss, masked_lm_loss, sso_loss) + output) if total_loss is not None else output
891
+
892
+ return PoNetForPreTrainingOutput(
893
+ loss=total_loss,
894
+ mlm_loss=masked_lm_loss,
895
+ sso_loss=sso_loss,
896
+ prediction_logits=prediction_scores,
897
+ seq_relationship_logits=seq_relationship_score,
898
+ hidden_states=outputs.hidden_states,
899
+ attentions=outputs.attentions,
900
+ )
901
+
902
+
903
+ @add_start_docstrings(
904
+ """
905
+ PoNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
906
+ output) e.g. for GLUE tasks.
907
+ """,
908
+ PONET_START_DOCSTRING,
909
+ )
910
+ class PoNetForSequenceClassification(PoNetPreTrainedModel):
911
+ def __init__(self, config):
912
+ super().__init__(config)
913
+ self.num_labels = config.num_labels
914
+ self.config = config
915
+
916
+ self.ponet = PoNetModel(config)
917
+ classifier_dropout = (
918
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
919
+ )
920
+ self.dropout = nn.Dropout(classifier_dropout)
921
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
922
+
923
+ # Initialize weights and apply final processing
924
+ self.post_init()
925
+
926
+ @add_start_docstrings_to_model_forward(PONET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
927
+ @add_code_sample_docstrings(
928
+ output_type=SequenceClassifierOutput,
929
+ config_class=_CONFIG_FOR_DOC,
930
+ )
931
+ def forward(
932
+ self,
933
+ input_ids: Optional[torch.Tensor] = None,
934
+ attention_mask: Optional[torch.Tensor] = None,
935
+ token_type_ids: Optional[torch.Tensor] = None,
936
+ segment_ids: Optional[torch.Tensor] = None,
937
+ position_ids: Optional[torch.Tensor] = None,
938
+ inputs_embeds: Optional[torch.Tensor] = None,
939
+ labels: Optional[torch.Tensor] = None,
940
+ output_attentions: Optional[bool] = None,
941
+ output_hidden_states: Optional[bool] = None,
942
+ return_dict: Optional[bool] = None,
943
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
944
+ r"""
945
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
946
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
947
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
948
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
949
+ """
950
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
951
+
952
+ outputs = self.ponet(
953
+ input_ids,
954
+ attention_mask=attention_mask,
955
+ token_type_ids=token_type_ids,
956
+ segment_ids=segment_ids,
957
+ position_ids=position_ids,
958
+ inputs_embeds=inputs_embeds,
959
+ output_attentions=output_attentions,
960
+ output_hidden_states=output_hidden_states,
961
+ return_dict=return_dict,
962
+ )
963
+
964
+ pooled_output = outputs[1]
965
+
966
+ pooled_output = self.dropout(pooled_output)
967
+ logits = self.classifier(pooled_output)
968
+
969
+ loss = None
970
+ if labels is not None:
971
+ if self.config.problem_type is None:
972
+ if self.num_labels == 1:
973
+ self.config.problem_type = "regression"
974
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
975
+ self.config.problem_type = "single_label_classification"
976
+ else:
977
+ self.config.problem_type = "multi_label_classification"
978
+
979
+ if self.config.problem_type == "regression":
980
+ loss_fct = MSELoss()
981
+ if self.num_labels == 1:
982
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
983
+ else:
984
+ loss = loss_fct(logits, labels)
985
+ elif self.config.problem_type == "single_label_classification":
986
+ loss_fct = CrossEntropyLoss()
987
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
988
+ elif self.config.problem_type == "multi_label_classification":
989
+ loss_fct = BCEWithLogitsLoss()
990
+ loss = loss_fct(logits, labels)
991
+ if not return_dict:
992
+ output = (logits,) + outputs[2:]
993
+ return ((loss,) + output) if loss is not None else output
994
+
995
+ return SequenceClassifierOutput(
996
+ loss=loss,
997
+ logits=logits,
998
+ hidden_states=outputs.hidden_states,
999
+ attentions=outputs.attentions,
1000
+ )
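
Instead of token-to-token attention, the `PoNetSelfAttention` above mixes three pooling paths: a sequence-level attention pooling, a segment-wise max pooling (`segment_max` over the boundaries from `get_segment_index`), and a sliding-window max (`get_win_max`). A toy sketch of the pooling helpers, assuming `modeling_ponet.py` from this commit is importable locally, a BERT-style tokenizer where `[CLS]`=101 and `[SEP]`=102 (the ids hard-coded above), and a PyTorch version that provides `Tensor.scatter_reduce`:

```python
import torch
from modeling_ponet import get_segment_index, get_token_type_mask, segment_max, get_win_max

# [CLS] w w [SEP] w w [SEP]
input_ids = torch.tensor([[101, 7592, 2088, 102, 2129, 2024, 102]])

seg = get_segment_index(input_ids)
print(seg)  # tensor([[0, 1, 1, 2, 3, 3, 4]]) — [CLS]/[SEP] split the sequence into segments
print(get_token_type_mask(input_ids))  # True exactly at the [CLS]/[SEP] positions

# segment_max takes the max of a feature within each segment and broadcasts it
# back to every token of that segment (here with a single feature channel).
h = torch.arange(7, dtype=torch.float).view(1, 7, 1)
print(segment_max(h, index=seg).squeeze(-1))  # tensor([[0., 2., 2., 3., 5., 5., 6.]])

# get_win_max is a size-3 sliding-window max over the sequence dimension.
print(get_win_max(h).squeeze(-1))  # tensor([[1., 2., 3., 4., 5., 6., 6.]])
```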
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:63250ca16370afadd1c7c65eff9c8c3ab2c3b5718ea3c6042fa874266be5279f
-size 497169377
+oid sha256:7641471f11332c2344c83a7c15efdcd3d4e05b1d693d40006f37b71ce69d6627
+size 590977869
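
Beyond pretraining, the uploaded module also ships task heads such as `PoNetForSequenceClassification`; that class is not wired into `auto_map`, so it has to be imported from the file directly. A hypothetical fine-tuning sketch, again assuming the repo id `chtan/ponet-base-uncased` and tokenizer files present in the repo:

```python
import torch
from transformers import AutoTokenizer
from modeling_ponet import PoNetForSequenceClassification  # file added in this commit

# Start from the pretraining checkpoint; the 2-way classifier head is newly initialized.
model = PoNetForSequenceClassification.from_pretrained("chtan/ponet-base-uncased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("chtan/ponet-base-uncased")

batch = tokenizer(["great movie", "terrible movie"], return_tensors="pt", padding=True)
labels = torch.tensor([1, 0])

outputs = model(**batch, labels=labels)
print(outputs.loss)          # cross-entropy over the two labels
print(outputs.logits.shape)  # (2, 2)
```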