kiddothe2b commited on
Commit
2438a16
1 Parent(s): e8eab0c

Add HAT implementation files

Browse files
Files changed (3) hide show
  1. configuration_hat.py +150 -0
  2. modelling_hat.py +0 -0
  3. tokenization_hat.py +249 -0
configuration_hat.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ """ HAT configuration"""
14
+ from collections import OrderedDict
15
+ from typing import Mapping
16
+
17
+ from transformers.onnx import OnnxConfig
18
+ from transformers.utils import logging
19
+ from transformers import PretrainedConfig
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ HAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
25
+ "kiddothe2b/hierarchical-transformer-base-4096": "https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096/resolve/main/config.json",
26
+ "kiddothe2b/adhoc-hierarchical-transformer-base-4096": "https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-base-4096/resolve/main/config.json",
27
+ }
28
+
29
+
30
+ class HATConfig(PretrainedConfig):
31
+ r"""
32
+ This is the configuration class to store the configuration of a :class:`~transformers.HAT`.
33
+ It is used to instantiate a HAT model according to the specified arguments,
34
+ defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
35
+ to that of the HAT `kiddothe2b/hierarchical-transformer-base-4096
36
+ <https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096>`__ architecture.
37
+
38
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
39
+ outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
40
+
41
+
42
+ Args:
43
+ vocab_size (:obj:`int`, `optional`, defaults to 30522):
44
+ Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
45
+ :obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or
46
+ :class:`~transformers.TFBertModel`.
47
+ max_sentences (:obj:`int`, `optional`, defaults to 64):
48
+ The maximum number of sentences that this model might ever be used with.
49
+ max_sentence_size (:obj:`int`, `optional`, defaults to 128):
50
+ The maximum sentence length that this model might ever be used with.
51
+ model_max_length (:obj:`int`, `optional`, defaults to 8192):
52
+ The maximum sequence length (max_sentences * max_sentence_size) that this model might ever be used with
53
+ encoder_layout (:obj:`Dict`):
54
+ The sentence/document encoder layout.
55
+ hidden_size (:obj:`int`, `optional`, defaults to 768):
56
+ Dimensionality of the encoder layers and the pooler layer.
57
+ num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
58
+ Number of hidden layers in the Transformer encoder.
59
+ num_attention_heads (:obj:`int`, `optional`, defaults to 12):
60
+ Number of attention heads for each attention layer in the Transformer encoder.
61
+ intermediate_size (:obj:`int`, `optional`, defaults to 3072):
62
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
63
+ hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
64
+ The non-linear activation function (function or string) in the encoder and pooler. If string,
65
+ :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
66
+ hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
67
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
68
+ attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
69
+ The dropout ratio for the attention probabilities.
70
+ max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
71
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
72
+ just in case (e.g., 512 or 1024 or 2048).
73
+ type_vocab_size (:obj:`int`, `optional`, defaults to 2):
74
+ The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or
75
+ :class:`~transformers.TFBertModel`.
76
+ initializer_range (:obj:`float`, `optional`, defaults to 0.02):
77
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
78
+ layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
79
+ The epsilon used by the layer normalization layers.
80
+ position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
81
+ Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
82
+ :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
83
+ :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
84
+ <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
85
+ `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
86
+ <https://arxiv.org/abs/2009.13658>`__.
87
+ use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
88
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
89
+ relevant if ``config.is_decoder=True``.
90
+ classifier_dropout (:obj:`float`, `optional`):
91
+ The dropout ratio for the classification head.
92
+ """
93
+ model_type = "hierarchical-transformer"
94
+
95
+ def __init__(
96
+ self,
97
+ vocab_size=30522,
98
+ hidden_size=768,
99
+ max_sentences=64,
100
+ max_sentence_size=128,
101
+ model_max_length=8192,
102
+ num_hidden_layers=12,
103
+ num_attention_heads=12,
104
+ intermediate_size=3072,
105
+ hidden_act="gelu",
106
+ hidden_dropout_prob=0.1,
107
+ attention_probs_dropout_prob=0.1,
108
+ max_position_embeddings=512,
109
+ type_vocab_size=2,
110
+ initializer_range=0.02,
111
+ layer_norm_eps=1e-12,
112
+ pad_token_id=0,
113
+ position_embedding_type="absolute",
114
+ encoder_layout=None,
115
+ use_cache=True,
116
+ classifier_dropout=None,
117
+ **kwargs
118
+ ):
119
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
120
+
121
+ self.vocab_size = vocab_size
122
+ self.hidden_size = hidden_size
123
+ self.max_sentences = max_sentences
124
+ self.max_sentence_size = max_sentence_size
125
+ self.model_max_length = model_max_length
126
+ self.encoder_layout = encoder_layout
127
+ self.num_hidden_layers = num_hidden_layers
128
+ self.num_attention_heads = num_attention_heads
129
+ self.hidden_act = hidden_act
130
+ self.intermediate_size = intermediate_size
131
+ self.hidden_dropout_prob = hidden_dropout_prob
132
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
133
+ self.max_position_embeddings = max_position_embeddings
134
+ self.type_vocab_size = type_vocab_size
135
+ self.initializer_range = initializer_range
136
+ self.layer_norm_eps = layer_norm_eps
137
+ self.position_embedding_type = position_embedding_type
138
+ self.use_cache = use_cache
139
+ self.classifier_dropout = classifier_dropout
140
+
141
+
142
+ class HATOnnxConfig(OnnxConfig):
143
+ @property
144
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
145
+ return OrderedDict(
146
+ [
147
+ ("input_ids", {0: "batch", 1: "sequence"}),
148
+ ("attention_mask", {0: "batch", 1: "sequence"}),
149
+ ]
150
+ )
modelling_hat.py ADDED
The diff for this file is too large to render. See raw diff
 
tokenization_hat.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ """Tokenization classes for HAT."""
14
+ import torch
15
+ from transformers import RobertaTokenizer, BertTokenizer
16
+ from .configuration_hat import HATConfig
17
+ from transformers.utils import logging
18
+ try:
19
+ from nltk import sent_tokenize
20
+ except:
21
+ raise Exception('NLTK is not installed! Install it with `pip install nltk`...')
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class HATTokenizer:
26
+ def __init__(self, tokenizer=None):
27
+ self._tokenizer = tokenizer
28
+ self.config = HATConfig.from_pretrained(self._tokenizer.name_or_path)
29
+ self._tokenizer.model_max_length = self.model_max_length
30
+ self.type2id = {'input_ids': (self._tokenizer.cls_token_id, self._tokenizer.pad_token_id),
31
+ 'token_type_ids': (0, 0),
32
+ 'attention_mask': (1, 0),
33
+ 'special_tokens_mask': (1, -100)}
34
+
35
+ @property
36
+ def model_max_length(self):
37
+ return self.config.model_max_length
38
+
39
+ @property
40
+ def mask_token(self):
41
+ return self._tokenizer.mask_token
42
+
43
+ @property
44
+ def mask_token_id(self):
45
+ return self._tokenizer.mask_token_id
46
+
47
+ @property
48
+ def pad_token_id(self):
49
+ return self._tokenizer.pad_token_id
50
+
51
+ @property
52
+ def cls_token_id(self):
53
+ return self._tokenizer.cls_token_id
54
+
55
+ @property
56
+ def sep_token_id(self):
57
+ return self._tokenizer.sep_token_id
58
+
59
+ @property
60
+ def vocab(self):
61
+ return self._tokenizer.vocab
62
+
63
+ def __len__(self):
64
+ """
65
+ Size of the full vocabulary with the added tokens.
66
+ """
67
+ return len(self._tokenizer)
68
+
69
+ def pad(self, *args, **kwargs):
70
+ return self._tokenizer.pad(*args, **kwargs)
71
+
72
+ def convert_tokens_to_ids(self, *args, **kwargs):
73
+ return self._tokenizer.convert_tokens_to_ids(*args, **kwargs)
74
+
75
+ def batch_decode(self, *args, **kwargs):
76
+ return self._tokenizer.batch_decode(*args, **kwargs)
77
+
78
+ def decode(self, *args, **kwargs):
79
+ return self._tokenizer.decode(*args, **kwargs)
80
+
81
+ def tokenize(self, text, **kwargs):
82
+ return self._tokenizer.tokenize(text, **kwargs)
83
+
84
+ def encode(self, text, **kwargs):
85
+ input_ids = self._tokenizer.encode_plus(text, add_special_tokens=False, **kwargs)
86
+ input_ids = self.chunks(input_ids[: self.model_max_length - self.config.max_sentences],
87
+ chunk_size=self.config.max_sentence_length, special_id=self.type2id['input_ids'])
88
+ return input_ids
89
+
90
+ def get_special_tokens_mask(self, *args, **kwargs):
91
+ return self._tokenizer.get_special_tokens_mask(*args, **kwargs)
92
+
93
+ @classmethod
94
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
95
+ try:
96
+ tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
97
+ except:
98
+ tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
99
+ return cls(tokenizer=tokenizer)
100
+
101
+ def save_pretrained(self, *args, **kwargs):
102
+ return self._tokenizer.save_pretrained( *args, **kwargs)
103
+
104
+ def __call__(self, text, **kwargs):
105
+ greedy_chunking = kwargs.pop('greedy_chunking', None)
106
+ text_pair = kwargs.pop('text_pair', None)
107
+ if isinstance(text[0], list):
108
+ batch = self.auto_chunking(text, **kwargs)
109
+ elif greedy_chunking:
110
+ # fixed uniform chunking
111
+ batch = self.uniform_chunking(text, **kwargs)
112
+ else:
113
+ # dynamic sentence splitting and grouping
114
+ batch = self.sentence_splitting(text, **kwargs)
115
+
116
+ if text_pair:
117
+ batch_b = self._tokenizer(text_pair, add_special_tokens=False,
118
+ padding=False, truncation=False)
119
+ for idx, sample in enumerate(batch['input_ids']):
120
+ n_sentences = sum(sample[::self.config.max_sentence_size])
121
+ for input_key in batch:
122
+ batch[input_key][idx][self.config.max_sentence_size * n_sentences:
123
+ self.config.max_sentence_size * (n_sentences + 1)] = \
124
+ self.pad_sentence(batch_b[input_key][idx],
125
+ special_id=(self.sep_token_id, self.pad_token_id)
126
+ if input_key == 'input_ids' else self.type2id[input_key])
127
+
128
+ return batch
129
+
130
+ def uniform_chunking(self, texts, **kwargs):
131
+ original_batch = self._tokenizer(texts, add_special_tokens=False, **kwargs)
132
+ batch = {input_type: [] for input_type in original_batch}
133
+ for input_type in original_batch:
134
+ fixed_batch = []
135
+ for example in original_batch[input_type]:
136
+ fixed_batch.append(self.chunks(example[: self.model_max_length - self.config.max_sentences],
137
+ chunk_size=self.config.max_sentence_length,
138
+ special_id=self.type2id[input_type]))
139
+ batch[input_type] = fixed_batch if isinstance(fixed_batch[0], list) else torch.stack(fixed_batch)
140
+
141
+ if kwargs['padding']:
142
+ batch = self.pad(batch,
143
+ padding=kwargs['padding'],
144
+ max_length=kwargs['max_length'],
145
+ pad_to_multiple_of=kwargs['max_length'])
146
+
147
+ return batch
148
+
149
+ def auto_chunking(self, texts, **kwargs):
150
+ batch = {}
151
+ for text_idx, text in enumerate(texts):
152
+ example_batch = self._tokenizer(text, add_special_tokens=False, **kwargs)
153
+ for input_key in example_batch:
154
+ key_inputs_list = []
155
+ for idx, example in enumerate(example_batch[input_key][:self.config.max_sentences]):
156
+ key_inputs_list.append(self.pad_sentence(example, special_id=self.type2id[input_key]))
157
+ if isinstance(key_inputs_list[0], list):
158
+ key_inputs_list = [token for sentence in key_inputs_list for token in sentence]
159
+ else:
160
+ key_inputs_list = torch.stack(key_inputs_list)
161
+ if input_key in batch:
162
+ batch[input_key].append(key_inputs_list)
163
+ else:
164
+ batch[input_key] = [key_inputs_list]
165
+
166
+ if kwargs['padding']:
167
+ batch = self.pad(batch,
168
+ padding=kwargs['padding'],
169
+ max_length=kwargs['max_length'],
170
+ pad_to_multiple_of=kwargs['max_length'])
171
+
172
+ return batch
173
+
174
+ def chunks(self, flat_inputs, chunk_size=128, special_id=0):
175
+ if isinstance(flat_inputs, list):
176
+ return self.list_chunks(flat_inputs, chunk_size, special_id)
177
+ else:
178
+ return self.tensor_chunks(flat_inputs, chunk_size, special_id)
179
+
180
+ def list_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
181
+ """Yield successive n-sized chunks from lst."""
182
+ structured_inputs = [[special_id[0] if sum(flat_inputs[i:i + chunk_size-1]) else special_id[1]]
183
+ + flat_inputs[i:i + chunk_size-1] for i in range(0, len(flat_inputs), chunk_size-1)]
184
+ return [token_input for sentence_inputs in structured_inputs for token_input in sentence_inputs]
185
+
186
+ def tensor_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
187
+ """Yield successive n-sized chunks from lst."""
188
+ structured_inputs = torch.stack([torch.cat((torch.tensor([special_id[0] if flat_inputs[i:i + chunk_size-1].sum() else special_id[1]], dtype=torch.int),
189
+ flat_inputs[i:i + chunk_size-1])) for i in range(0, len(flat_inputs), chunk_size-1)])
190
+ return structured_inputs.reshape(-1)
191
+
192
+ def sentence_splitting(self, texts, **kwargs):
193
+ fixed_batch = []
194
+ doc_out = {}
195
+ for text in texts:
196
+ # sentence splitting
197
+ sentences = sent_tokenize(text)
198
+ # tokenization of sentences
199
+ sentences = self._tokenizer(sentences, add_special_tokens=False, padding=False, truncation=False)
200
+ # sentence grouping - merging short sentences to minimize padding
201
+ doc_out = self.sentence_grouping(sentences)
202
+ fixed_batch.append(doc_out)
203
+ # batchify examples
204
+ batch = {input_type: [] for input_type in doc_out}
205
+ for input_type in batch:
206
+ batch[input_type] = [example[input_type] for example in fixed_batch]
207
+ if not isinstance(batch[input_type][0], list):
208
+ batch[input_type] = torch.stack(batch[input_type])
209
+
210
+ if kwargs['padding']:
211
+ batch = self.pad(batch,
212
+ padding=kwargs['padding'],
213
+ max_length=kwargs['max_length'],
214
+ pad_to_multiple_of=kwargs['max_length'])
215
+
216
+ return batch
217
+
218
+ def sentence_grouping(self, sentences):
219
+ doc_out = {input_type: [] for input_type in sentences}
220
+ for input_type in sentences:
221
+ tmp_doc = []
222
+ tmp_sentence = []
223
+ for example in sentences[input_type]:
224
+ if len(tmp_doc) >= self.config.max_sentences:
225
+ break
226
+ if len(tmp_sentence) + len(example) <= self.config.max_sentence_length - 1:
227
+ tmp_sentence.extend(example)
228
+ else:
229
+ tmp_doc.append(self.pad_sentence(tmp_sentence if len(tmp_sentence) else example,
230
+ chunk_size=self.config.max_sentence_length,
231
+ special_id=self.type2id[input_type]))
232
+ tmp_sentence = example if len(tmp_sentence) else example[self.config.max_sentence_length:]
233
+ if len(tmp_sentence) and len(tmp_doc) < self.config.max_sentences:
234
+ tmp_doc.append(self.pad_sentence(tmp_sentence,
235
+ chunk_size=self.config.max_sentence_length,
236
+ special_id=self.type2id[input_type]))
237
+ doc_out[input_type] = [token for sentence in tmp_doc for token in sentence]
238
+ return doc_out
239
+
240
+ def pad_sentence(self, flat_input, chunk_size=128, special_id=(0, 0)):
241
+ if isinstance(flat_input, list):
242
+ return [special_id[0]] + flat_input[:chunk_size-1] + [self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1)
243
+ else:
244
+ return torch.cat((torch.tensor([special_id[0] if flat_input[:chunk_size-1].sum()
245
+ else special_id[1]], dtype=torch.int),
246
+ flat_input[:chunk_size-1],
247
+ torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
248
+ ))
249
+