kiddothe2b committed
Commit 278b6ef
1 Parent(s): 9a04fd2

Add HAT implementation files

Files changed (3):
  1. configuration_hat.py +150 -0
  2. modelling_hat.py +0 -0
  3. tokenization_hat.py +244 -0
configuration_hat.py ADDED
@@ -0,0 +1,150 @@
+ # coding=utf-8
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ HAT configuration"""
+ from collections import OrderedDict
+ from typing import Mapping
+
+ from transformers.onnx import OnnxConfig
+ from transformers.utils import logging
+ from transformers import PretrainedConfig
+
+
+ logger = logging.get_logger(__name__)
+
+ HAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+     "kiddothe2b/hierarchical-transformer-base-4096": "https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096/resolve/main/config.json",
+     "kiddothe2b/adhoc-hierarchical-transformer-base-4096": "https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-base-4096/resolve/main/config.json",
+ }
+
+
+ class HATConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a HAT model. It is used to instantiate a HAT
+     model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a similar configuration to that of the HAT
+     `kiddothe2b/hierarchical-transformer-base-4096
+     <https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096>`__ architecture.
+
+     Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
+     outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
+
+     Args:
+         vocab_size (:obj:`int`, `optional`, defaults to 30522):
+             Vocabulary size of the HAT model. Defines the number of different tokens that can be represented by
+             the :obj:`inputs_ids` passed when calling the model.
+         max_sentences (:obj:`int`, `optional`, defaults to 64):
+             The maximum number of sentences that this model might ever be used with.
+         max_sentence_size (:obj:`int`, `optional`, defaults to 128):
+             The maximum sentence length that this model might ever be used with.
+         model_max_length (:obj:`int`, `optional`, defaults to 8192):
+             The maximum sequence length (max_sentences * max_sentence_size) that this model might ever be used with.
+         encoder_layout (:obj:`Dict`):
+             The sentence/document encoder layout.
+         hidden_size (:obj:`int`, `optional`, defaults to 768):
+             Dimensionality of the encoder layers and the pooler layer.
+         num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (:obj:`int`, `optional`, defaults to 12):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         intermediate_size (:obj:`int`, `optional`, defaults to 3072):
+             Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+         hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
+             The non-linear activation function (function or string) in the encoder and pooler. If string,
+             :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
+         hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
+             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+         attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
+             The dropout ratio for the attention probabilities.
+         max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
+             The maximum sequence length that this model might ever be used with. Typically set this to something
+             large just in case (e.g., 512 or 1024 or 2048).
+         type_vocab_size (:obj:`int`, `optional`, defaults to 2):
+             The vocabulary size of the :obj:`token_type_ids` passed when calling the model.
+         initializer_range (:obj:`float`, `optional`, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
+             The epsilon used by the layer normalization layers.
+         position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
+             Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
+             :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
+             :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
+             <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer
+             to `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
+             <https://arxiv.org/abs/2009.13658>`__.
+         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if ``config.is_decoder=True``.
+         classifier_dropout (:obj:`float`, `optional`):
+             The dropout ratio for the classification head.
+     """
+     model_type = "hierarchical-transformer"
+
+     def __init__(
+         self,
+         vocab_size=30522,
+         hidden_size=768,
+         max_sentences=64,
+         max_sentence_size=128,
+         model_max_length=8192,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_act="gelu",
+         hidden_dropout_prob=0.1,
+         attention_probs_dropout_prob=0.1,
+         max_position_embeddings=512,
+         type_vocab_size=2,
+         initializer_range=0.02,
+         layer_norm_eps=1e-12,
+         pad_token_id=0,
+         position_embedding_type="absolute",
+         encoder_layout=None,
+         use_cache=True,
+         classifier_dropout=None,
+         **kwargs
+     ):
+         super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.max_sentences = max_sentences
+         self.max_sentence_size = max_sentence_size
+         self.model_max_length = model_max_length
+         self.encoder_layout = encoder_layout
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.hidden_act = hidden_act
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.type_vocab_size = type_vocab_size
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.position_embedding_type = position_embedding_type
+         self.use_cache = use_cache
+         self.classifier_dropout = classifier_dropout
+
+
+ class HATOnnxConfig(OnnxConfig):
+     @property
+     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+         return OrderedDict(
+             [
+                 ("input_ids", {0: "batch", 1: "sequence"}),
+                 ("attention_mask", {0: "batch", 1: "sequence"}),
+             ]
+         )
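
For quick reference, a minimal usage sketch of the configuration class above. The package name `hat` and the `encoder_layout` schema shown are illustrative assumptions, not part of this commit; the released checkpoints ship their own config.json with the actual layout.

from hat.configuration_hat import HATConfig  # assumes the files in this commit are packaged as `hat`

# Load the released configuration directly from the Hub.
config = HATConfig.from_pretrained("kiddothe2b/hierarchical-transformer-base-4096")
print(config.model_type, config.max_sentences, config.max_sentence_size)

# Or build one from scratch; the per-layer layout dict below is only a guess at the expected schema.
encoder_layout = {str(i): {"sentence_encoder": True, "document_encoder": (i % 2 == 1)}
                  for i in range(12)}
custom_config = HATConfig(max_sentences=64, max_sentence_size=128, model_max_length=8192,
                          encoder_layout=encoder_layout)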
modelling_hat.py ADDED
The diff for this file is too large to render. See raw diff
tokenization_hat.py ADDED
@@ -0,0 +1,244 @@
+ # coding=utf-8
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tokenization classes for HAT."""
+ import torch
+ from transformers import AutoTokenizer
+ from .configuration_hat import HATConfig
+ from transformers.utils import logging
+ try:
+     from nltk import sent_tokenize
+ except ImportError:
+     raise ImportError('NLTK is not installed! Install it with `pip install nltk`...')
+ logger = logging.get_logger(__name__)
+
+
+ class HATTokenizer:
+     def __init__(self, tokenizer=None):
+         self._tokenizer = tokenizer
+         self.config = HATConfig.from_pretrained(self._tokenizer.name_or_path)
+         self._tokenizer.model_max_length = self.model_max_length
+         # (id placed at the start of a non-empty sentence block, id used for empty blocks)
+         self.type2id = {'input_ids': (self._tokenizer.cls_token_id, self._tokenizer.pad_token_id),
+                         'token_type_ids': (0, 0),
+                         'attention_mask': (1, 0),
+                         'special_tokens_mask': (1, -100)}
+
+     @property
+     def model_max_length(self):
+         return self.config.model_max_length
+
+     @property
+     def mask_token(self):
+         return self._tokenizer.mask_token
+
+     @property
+     def mask_token_id(self):
+         return self._tokenizer.mask_token_id
+
+     @property
+     def pad_token_id(self):
+         return self._tokenizer.pad_token_id
+
+     @property
+     def cls_token_id(self):
+         return self._tokenizer.cls_token_id
+
+     @property
+     def sep_token_id(self):
+         return self._tokenizer.sep_token_id
+
+     @property
+     def vocab(self):
+         return self._tokenizer.vocab
+
+     def __len__(self):
+         """
+         Size of the full vocabulary with the added tokens.
+         """
+         return len(self._tokenizer)
+
+     def pad(self, *args, **kwargs):
+         return self._tokenizer.pad(*args, **kwargs)
+
+     def convert_tokens_to_ids(self, *args, **kwargs):
+         return self._tokenizer.convert_tokens_to_ids(*args, **kwargs)
+
+     def batch_decode(self, *args, **kwargs):
+         return self._tokenizer.batch_decode(*args, **kwargs)
+
+     def decode(self, *args, **kwargs):
+         return self._tokenizer.decode(*args, **kwargs)
+
+     def tokenize(self, text, **kwargs):
+         return self._tokenizer.tokenize(text, **kwargs)
+
+     def encode(self, text, **kwargs):
+         # flat list of token ids (no special tokens), then regrouped into fixed-size sentence blocks
+         input_ids = self._tokenizer.encode(text, add_special_tokens=False, **kwargs)
+         input_ids = self.chunks(input_ids[: self.model_max_length - self.config.max_sentences],
+                                 chunk_size=self.config.max_sentence_length, special_id=self.type2id['input_ids'])
+         return input_ids
+
+     def get_special_tokens_mask(self, *args, **kwargs):
+         return self._tokenizer.get_special_tokens_mask(*args, **kwargs)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         return cls(tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs))
+
+     def save_pretrained(self, *args, **kwargs):
+         return self._tokenizer.save_pretrained(*args, **kwargs)
+
+     def __call__(self, text, **kwargs):
+         greedy_chunking = kwargs.pop('greedy_chunking', None)
+         text_pair = kwargs.pop('text_pair', None)
+         if isinstance(text[0], list):
+             # inputs are already split into sentences
+             batch = self.auto_chunking(text, **kwargs)
+         elif greedy_chunking:
+             # fixed uniform chunking
+             batch = self.uniform_chunking(text, **kwargs)
+         else:
+             # dynamic sentence splitting and grouping
+             batch = self.sentence_splitting(text, **kwargs)
+
+         if text_pair:
+             batch_b = self._tokenizer(text_pair, add_special_tokens=False,
+                                       padding=False, truncation=False)
+             for idx, sample in enumerate(batch['input_ids']):
+                 n_sentences = sum(sample[::self.config.max_sentence_size])
+                 for input_key in batch:
+                     batch[input_key][idx][self.config.max_sentence_size * n_sentences:
+                                           self.config.max_sentence_size * (n_sentences + 1)] = \
+                         self.pad_sentence(batch_b[input_key][idx],
+                                           special_id=(self.sep_token_id, self.pad_token_id)
+                                           if input_key == 'input_ids' else self.type2id[input_key])
+
+         return batch
+
+     def uniform_chunking(self, texts, **kwargs):
+         original_batch = self._tokenizer(texts, add_special_tokens=False, **kwargs)
+         batch = {input_type: [] for input_type in original_batch}
+         for input_type in original_batch:
+             fixed_batch = []
+             for example in original_batch[input_type]:
+                 fixed_batch.append(self.chunks(example[: self.model_max_length - self.config.max_sentences],
+                                                chunk_size=self.config.max_sentence_length,
+                                                special_id=self.type2id[input_type]))
+             batch[input_type] = fixed_batch if isinstance(fixed_batch[0], list) else torch.stack(fixed_batch)
+
+         if kwargs['padding']:
+             batch = self.pad(batch,
+                              padding=kwargs['padding'],
+                              max_length=kwargs['max_length'],
+                              pad_to_multiple_of=kwargs['max_length'])
+
+         return batch
+
+     def auto_chunking(self, texts, **kwargs):
+         batch = {}
+         for text_idx, text in enumerate(texts):
+             example_batch = self._tokenizer(text, add_special_tokens=False, **kwargs)
+             for input_key in example_batch:
+                 key_inputs_list = []
+                 for idx, example in enumerate(example_batch[input_key][:self.config.max_sentences]):
+                     key_inputs_list.append(self.pad_sentence(example, special_id=self.type2id[input_key]))
+                 if isinstance(key_inputs_list[0], list):
+                     key_inputs_list = [token for sentence in key_inputs_list for token in sentence]
+                 else:
+                     key_inputs_list = torch.stack(key_inputs_list)
+                 if input_key in batch:
+                     batch[input_key].append(key_inputs_list)
+                 else:
+                     batch[input_key] = [key_inputs_list]
+
+         if kwargs['padding']:
+             batch = self.pad(batch,
+                              padding=kwargs['padding'],
+                              max_length=kwargs['max_length'],
+                              pad_to_multiple_of=kwargs['max_length'])
+
+         return batch
+
+     def chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
+         if isinstance(flat_inputs, list):
+             return self.list_chunks(flat_inputs, chunk_size, special_id)
+         else:
+             return self.tensor_chunks(flat_inputs, chunk_size, special_id)
+
+     def list_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
+         """Split a flat list into fixed-size chunks, prepending a special id to each chunk."""
+         structured_inputs = [[special_id[0] if sum(flat_inputs[i:i + chunk_size-1]) else special_id[1]]
+                              + flat_inputs[i:i + chunk_size-1] for i in range(0, len(flat_inputs), chunk_size-1)]
+         return [token_input for sentence_inputs in structured_inputs for token_input in sentence_inputs]
+
+     def tensor_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
+         """Split a flat tensor into fixed-size chunks, prepending a special id to each chunk."""
+         structured_inputs = torch.stack([torch.cat((torch.tensor([special_id[0] if flat_inputs[i:i + chunk_size-1].sum() else special_id[1]], dtype=torch.int),
+                                                     flat_inputs[i:i + chunk_size-1])) for i in range(0, len(flat_inputs), chunk_size-1)])
+         return structured_inputs.reshape(-1)
+
+     def sentence_splitting(self, texts, **kwargs):
+         fixed_batch = []
+         doc_out = {}
+         for text in texts:
+             # sentence splitting
+             sentences = sent_tokenize(text)
+             # tokenization of sentences
+             sentences = self._tokenizer(sentences, add_special_tokens=False, padding=False, truncation=False)
+             # sentence grouping - merging short sentences to minimize padding
+             doc_out = self.sentence_grouping(sentences)
+             fixed_batch.append(doc_out)
+         # batchify examples
+         batch = {input_type: [] for input_type in doc_out}
+         for input_type in batch:
+             batch[input_type] = [example[input_type] for example in fixed_batch]
+             if not isinstance(batch[input_type][0], list):
+                 batch[input_type] = torch.stack(batch[input_type])
+
+         if kwargs['padding']:
+             batch = self.pad(batch,
+                              padding=kwargs['padding'],
+                              max_length=kwargs['max_length'],
+                              pad_to_multiple_of=kwargs['max_length'])
+
+         return batch
+
+     def sentence_grouping(self, sentences):
+         doc_out = {input_type: [] for input_type in sentences}
+         for input_type in sentences:
+             tmp_doc = []
+             tmp_sentence = []
+             for example in sentences[input_type]:
+                 if len(tmp_doc) >= self.config.max_sentences:
+                     break
+                 if len(tmp_sentence) + len(example) <= self.config.max_sentence_length - 1:
+                     tmp_sentence.extend(example)
+                 else:
+                     tmp_doc.append(self.pad_sentence(tmp_sentence if len(tmp_sentence) else example,
+                                                      chunk_size=self.config.max_sentence_length,
+                                                      special_id=self.type2id[input_type]))
+                     tmp_sentence = example if len(tmp_sentence) else example[self.config.max_sentence_length:]
+             if len(tmp_sentence) and len(tmp_doc) < self.config.max_sentences:
+                 tmp_doc.append(self.pad_sentence(tmp_sentence,
+                                                  chunk_size=self.config.max_sentence_length,
+                                                  special_id=self.type2id[input_type]))
+             doc_out[input_type] = [token for sentence in tmp_doc for token in sentence]
+         return doc_out
+
+     def pad_sentence(self, flat_input, chunk_size=128, special_id=(0, 0)):
+         if isinstance(flat_input, list):
+             return [special_id[0]] + flat_input[:chunk_size-1] + [self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1)
+         else:
+             return torch.cat((torch.tensor([special_id[0] if flat_input[:chunk_size-1].sum()
+                                             else special_id[1]], dtype=torch.int),
+                               flat_input[:chunk_size-1],
+                               torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
+                               ))
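
A short, hedged sketch of how this tokenizer wrapper is meant to be driven, assuming the files in this commit are importable as a package named `hat` (a hypothetical name). Note that `padding` and `max_length` must be passed explicitly, since `__call__` reads them from its keyword arguments.

import nltk
nltk.download("punkt")  # sent_tokenize() needs the NLTK punkt sentence splitter

from hat.tokenization_hat import HATTokenizer  # hypothetical package name

tokenizer = HATTokenizer.from_pretrained("kiddothe2b/hierarchical-transformer-base-4096")

docs = ["The first document. It has two sentences.",
        "A second, much shorter document."]

# Default path: dynamic sentence splitting and grouping into fixed-size sentence blocks.
batch = tokenizer(docs, padding="max_length", max_length=tokenizer.model_max_length)

# Each example is laid out as max_sentences blocks of max_sentence_size tokens,
# with a CLS-like id at the start of every non-empty block.
print(len(batch["input_ids"][0]))  # expected: tokenizer.model_max_length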