jackstanley committed on
Commit
5643fe7
1 Parent(s): 5e8afc1

Upload 11 files

README.md ADDED
@@ -0,0 +1,116 @@
+ ---
+ license: cc-by-sa-4.0
+ pipeline_tag: fill-mask
+ arxiv: 2210.05529
+ language: en
+ thumbnail: https://github.com/coastalcph/hierarchical-transformers/raw/main/data/figures/hat_encoder.png
+ tags:
+ - long-documents
+ datasets:
+ - c4
+ model-index:
+ - name: kiddothe2b/hierarchical-transformer-base-4096
+   results: []
+ ---
+
+ # Hierarchical Attention Transformer (HAT) / hierarchical-transformer-base-4096
+
+ ## Model description
+
+ This is a Hierarchical Attention Transformer (HAT) model, as presented in [An Exploration of Hierarchical Attention Transformers for Efficient Long Document Classification (Chalkidis et al., 2022)](https://arxiv.org/abs/2210.05529).
+
+ The model was warm-started from the weights of RoBERTa (Liu et al., 2019) and further pre-trained with masked language modeling (MLM) on long sequences, following the paradigm of Longformer (Beltagy et al., 2020). It supports sequences of up to 4,096 tokens.
+
+ HAT uses hierarchical attention, a combination of segment-wise and cross-segment attention operations. You can think of segments as paragraphs or sentences.
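+
+ In this checkpoint, each segment holds up to 128 tokens and a document holds up to 32 segments, i.e. 32 × 128 = 4,096 tokens. Per the `encoder_layout` in `config.json`, every layer applies segment-wise attention, while cross-segment (document-level) attention is added at layers 2, 5, 8, and 11. A minimal sketch for inspecting this layout (attribute names as exposed by `configuration_hat.HATConfig`):
+
+ ```python
+ from transformers import AutoConfig
+
+ config = AutoConfig.from_pretrained("kiddothe2b/hierarchical-transformer-base-4096", trust_remote_code=True)
+ print(config.max_sentences, config.max_sentence_size)  # 32 segments of up to 128 tokens each
+ # Layers whose "document_encoder" flag is set apply cross-segment attention.
+ print(sorted(int(idx) for idx, layer in config.encoder_layout.items() if layer["document_encoder"]))
+ ```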
+
+ ## Intended uses & limitations
+
+ You can use the raw model for masked language modeling, but it is mostly intended to be fine-tuned on a downstream task.
+ See the [model hub](https://huggingface.co/models?filter=hierarchical-transformer) to look for other versions of HAT, or for versions fine-tuned on a task that interests you.
+
+ Note that this model is primarily aimed at being fine-tuned on tasks that use the whole document to make decisions, such as document classification, sequential sentence classification, or question answering.
+
+ ## How to use
+
+ You can use this model directly for masked language modeling:
+
+ ```python
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+ tokenizer = AutoTokenizer.from_pretrained("kiddothe2b/hierarchical-transformer-base-4096", trust_remote_code=True)
+ mlm_model = AutoModelForMaskedLM.from_pretrained("kiddothe2b/hierarchical-transformer-base-4096", trust_remote_code=True)
+ ```
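+
+ As a minimal inference sketch (assuming the remote-code tokenizer follows the standard call signature for padding and truncation; it may return Python lists rather than tensors):
+
+ ```python
+ import torch
+
+ text = "Hierarchical attention transformers process <mask> documents efficiently."
+ # Sketch only: argument names follow the standard tokenizer API.
+ batch = tokenizer([text], padding="max_length", max_length=4096, truncation=True)
+ input_ids = torch.as_tensor(batch["input_ids"])
+ attention_mask = torch.as_tensor(batch["attention_mask"])
+ with torch.no_grad():
+     logits = mlm_model(input_ids=input_ids, attention_mask=attention_mask).logits
+ # Decode the highest-scoring token at the masked position.
+ mask_positions = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
+ print(tokenizer.decode(logits[mask_positions].argmax(dim=-1)))
+ ```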
+
+ You can also fine-tune it for SequenceClassification, SequentialSentenceClassification, and MultipleChoice downstream tasks:
+
+ ```python
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ tokenizer = AutoTokenizer.from_pretrained("kiddothe2b/hierarchical-transformer-base-4096", trust_remote_code=True)
+ doc_classifier = AutoModelForSequenceClassification.from_pretrained("kiddothe2b/hierarchical-transformer-base-4096", trust_remote_code=True)
+ ```
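+
+ A minimal fine-tuning sketch with the `Trainer` API, reusing `tokenizer` and `doc_classifier` from the snippet above; the dataset and hyperparameters below are illustrative placeholders (not the ones used for this model) and assume the remote-code tokenizer plugs into the standard workflow:
+
+ ```python
+ from datasets import load_dataset
+ from transformers import Trainer, TrainingArguments
+
+ # Hypothetical dataset: any long-document classification set with "text" and "label" columns works the same way.
+ dataset = load_dataset("imdb")
+
+ def tokenize(batch):
+     return tokenizer(batch["text"], padding="max_length", max_length=4096, truncation=True)
+
+ dataset = dataset.map(tokenize, batched=True)
+
+ trainer = Trainer(
+     model=doc_classifier,
+     args=TrainingArguments(output_dir="hat-imdb", learning_rate=3e-5, num_train_epochs=3),
+     train_dataset=dataset["train"],
+     eval_dataset=dataset["test"],
+ )
+ trainer.train()
+ ```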
+
+ ## Limitations and bias
+
+ The training data used for this model contains a lot of unfiltered content from the internet, which is far from
+ neutral. Therefore, the model can make biased predictions.
+
+
+ ## Training procedure
+
+ ### Training and evaluation data
+
+ The model was warm-started from the [roberta-base](https://huggingface.co/roberta-base) checkpoint and further pre-trained for an additional 50k steps on long sequences (> 1,024 subwords) from [C4](https://huggingface.co/datasets/c4) (Raffel et al., 2020).
+
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training:
+ - learning_rate: 0.0001
+ - train_batch_size: 2
+ - eval_batch_size: 2
+ - seed: 42
+ - distributed_type: tpu
+ - num_devices: 8
+ - gradient_accumulation_steps: 8
+ - total_train_batch_size: 128
+ - total_eval_batch_size: 16
+ - optimizer: Adam with betas=(0.9, 0.999) and epsilon=1e-08
+ - lr_scheduler_type: linear
+ - lr_scheduler_warmup_ratio: 0.1
+ - training_steps: 50000
+
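+ The effective batch sizes follow from the per-device settings: total_train_batch_size = 2 × 8 devices × 8 accumulation steps = 128, and total_eval_batch_size = 2 × 8 = 16.
+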
+ ### Training results
+
+ | Training Loss | Epoch | Step  | Validation Loss |
+ |:-------------:|:-----:|:-----:|:---------------:|
+ | 1.7437        | 0.2   | 10000 | 1.6370          |
+ | 1.6994        | 0.4   | 20000 | 1.6054          |
+ | 1.6726        | 0.6   | 30000 | 1.5718          |
+ | 1.6440        | 0.8   | 40000 | 1.5526          |
+ | 1.6299        | 1.0   | 50000 | 1.5368          |
+
+
+ ### Framework versions
+
+ - Transformers 4.19.0.dev0
+ - PyTorch 1.11.0+cu102
+ - Datasets 2.0.0
+ - Tokenizers 0.11.6
+
+
+ ## Citing
+
+ If you use HAT in your research, please cite:
+
+ [An Exploration of Hierarchical Attention Transformers for Efficient Long Document Classification](https://arxiv.org/abs/2210.05529). Ilias Chalkidis, Xiang Dai, Manos Fergadiotis, Prodromos Malakasiotis, and Desmond Elliott. 2022. arXiv:2210.05529 (Preprint).
+
+ ```bibtex
+ @misc{chalkidis-etal-2022-hat,
+   url = {https://arxiv.org/abs/2210.05529},
+   author = {Chalkidis, Ilias and Dai, Xiang and Fergadiotis, Manos and Malakasiotis, Prodromos and Elliott, Desmond},
+   title = {An Exploration of Hierarchical Attention Transformers for Efficient Long Document Classification},
+   publisher = {arXiv},
+   year = {2022},
+ }
+ ```
+
+
config.json ADDED
@@ -0,0 +1,93 @@
+ {
+   "_name_or_path": "kiddothe2b/hierarchical-transformer-base-4096",
+   "architectures": [
+     "HATForMaskedLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_hat.HATConfig",
+     "AutoTokenizer": "tokenization_hat.HATTokenizer",
+     "AutoModel": "modelling_hat.HATModel",
+     "AutoModelForMaskedLM": "modelling_hat.HATForMaskedLM",
+     "AutoModelForMultipleChoice": "modelling_hat.HATForMultipleChoice",
+     "AutoModelForQuestionAnswering": "modelling_hat.HATForQuestionAnswering",
+     "AutoModelForSequenceClassification": "modelling_hat.HATForSequenceClassification",
+     "AutoModelForTokenClassification": "modelling_hat.HATForTokenClassification"
+   },
+   "attention_probs_dropout_prob": 0.1,
+   "bos_token_id": 0,
+   "classifier_dropout": null,
+   "encoder_layout": {
+     "0": {
+       "document_encoder": false,
+       "sentence_encoder": true
+     },
+     "1": {
+       "document_encoder": false,
+       "sentence_encoder": true
+     },
+     "10": {
+       "document_encoder": false,
+       "sentence_encoder": true
+     },
+     "11": {
+       "document_encoder": true,
+       "sentence_encoder": true
+     },
+     "2": {
+       "document_encoder": true,
+       "sentence_encoder": true
+     },
+     "3": {
+       "document_encoder": false,
+       "sentence_encoder": true
+     },
+     "4": {
+       "document_encoder": false,
+       "sentence_encoder": true
+     },
+     "5": {
+       "document_encoder": true,
+       "sentence_encoder": true
+     },
+     "6": {
+       "document_encoder": false,
+       "sentence_encoder": true
+     },
+     "7": {
+       "document_encoder": false,
+       "sentence_encoder": true
+     },
+     "8": {
+       "document_encoder": true,
+       "sentence_encoder": true
+     },
+     "9": {
+       "document_encoder": false,
+       "sentence_encoder": true
+     }
+   },
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 130,
+   "max_sentence_length": 128,
+   "max_sentence_size": 128,
+   "max_sentences": 32,
+   "model_max_length": 4096,
+   "model_type": "hierarchical-transformer",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "output_past": true,
+   "pad_token_id": 1,
+   "parameters": 136350720,
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.19.0.dev0",
+   "type_vocab_size": 1,
+   "use_cache": true,
+   "vocab_size": 50265
+ }
configuration_hat.py ADDED
@@ -0,0 +1,150 @@
+ # coding=utf-8
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ HAT configuration"""
+ from collections import OrderedDict
+ from typing import Mapping
+
+ from transformers.onnx import OnnxConfig
+ from transformers.utils import logging
+ from transformers import PretrainedConfig
+
+
+ logger = logging.get_logger(__name__)
+
+ HAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+     "kiddothe2b/hierarchical-transformer-base-4096": "https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096/resolve/main/config.json",
+     "kiddothe2b/adhoc-hierarchical-transformer-base-4096": "https://huggingface.co/kiddothe2b/adhoc-hierarchical-transformer-base-4096/resolve/main/config.json",
+ }
+
+
+ class HATConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a :class:`~transformers.HAT`.
+     It is used to instantiate a HAT model according to the specified arguments,
+     defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
+     to that of the HAT `kiddothe2b/hierarchical-transformer-base-4096
+     <https://huggingface.co/kiddothe2b/hierarchical-transformer-base-4096>`__ architecture.
+
+     Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
+     outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
+
+
+     Args:
+         vocab_size (:obj:`int`, `optional`, defaults to 30522):
+             Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
+             :obj:`inputs_ids` passed when calling :class:`~transformers.BertModel` or
+             :class:`~transformers.TFBertModel`.
+         max_sentences (:obj:`int`, `optional`, defaults to 64):
+             The maximum number of sentences that this model might ever be used with.
+         max_sentence_size (:obj:`int`, `optional`, defaults to 128):
+             The maximum sentence length that this model might ever be used with.
+         model_max_length (:obj:`int`, `optional`, defaults to 8192):
+             The maximum sequence length (max_sentences * max_sentence_size) that this model might ever be used with.
+         encoder_layout (:obj:`Dict`):
+             The sentence/document encoder layout.
+         hidden_size (:obj:`int`, `optional`, defaults to 768):
+             Dimensionality of the encoder layers and the pooler layer.
+         num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (:obj:`int`, `optional`, defaults to 12):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         intermediate_size (:obj:`int`, `optional`, defaults to 3072):
+             Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+         hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
+             The non-linear activation function (function or string) in the encoder and pooler. If string,
+             :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
+         hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
+             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+         attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
+             The dropout ratio for the attention probabilities.
+         max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
+             The maximum sequence length that this model might ever be used with. Typically set this to something large
+             just in case (e.g., 512 or 1024 or 2048).
+         type_vocab_size (:obj:`int`, `optional`, defaults to 2):
+             The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.BertModel` or
+             :class:`~transformers.TFBertModel`.
+         initializer_range (:obj:`float`, `optional`, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
+             The epsilon used by the layer normalization layers.
+         position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
+             Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
+             :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
+             :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
+             <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
+             `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
+             <https://arxiv.org/abs/2009.13658>`__.
+         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if ``config.is_decoder=True``.
+         classifier_dropout (:obj:`float`, `optional`):
+             The dropout ratio for the classification head.
+     """
+     model_type = "hierarchical-transformer"
+
+     def __init__(
+         self,
+         vocab_size=30522,
+         hidden_size=768,
+         max_sentences=64,
+         max_sentence_size=128,
+         model_max_length=8192,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_act="gelu",
+         hidden_dropout_prob=0.1,
+         attention_probs_dropout_prob=0.1,
+         max_position_embeddings=512,
+         type_vocab_size=2,
+         initializer_range=0.02,
+         layer_norm_eps=1e-12,
+         pad_token_id=0,
+         position_embedding_type="absolute",
+         encoder_layout=None,
+         use_cache=True,
+         classifier_dropout=None,
+         **kwargs
+     ):
+         super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.max_sentences = max_sentences
+         self.max_sentence_size = max_sentence_size
+         self.model_max_length = model_max_length
+         self.encoder_layout = encoder_layout
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.hidden_act = hidden_act
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.type_vocab_size = type_vocab_size
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.position_embedding_type = position_embedding_type
+         self.use_cache = use_cache
+         self.classifier_dropout = classifier_dropout
+
+
+ class HATOnnxConfig(OnnxConfig):
+     @property
+     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+         return OrderedDict(
+             [
+                 ("input_ids", {0: "batch", 1: "sequence"}),
+                 ("attention_mask", {0: "batch", 1: "sequence"}),
+             ]
+         )
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modelling_hat.py ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenization_hat.py ADDED
@@ -0,0 +1,249 @@
+ # coding=utf-8
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tokenization classes for HAT."""
+ import torch
+ from transformers import RobertaTokenizer, BertTokenizer
+ from .configuration_hat import HATConfig
+ from transformers.utils import logging
+ try:
+     from nltk import sent_tokenize
+ except ImportError:
+     raise Exception('NLTK is not installed! Install it with `pip install nltk`...')
+ logger = logging.get_logger(__name__)
+
+
+ class HATTokenizer:
+     def __init__(self, tokenizer=None):
+         self._tokenizer = tokenizer
+         self.config = HATConfig.from_pretrained(self._tokenizer.name_or_path)
+         self._tokenizer.model_max_length = self.model_max_length
+         # Per input type: (value placed at the start of a non-empty segment, value used for empty segments)
+         self.type2id = {'input_ids': (self._tokenizer.cls_token_id, self._tokenizer.pad_token_id),
+                         'token_type_ids': (0, 0),
+                         'attention_mask': (1, 0),
+                         'special_tokens_mask': (1, -100)}
+
+     @property
+     def model_max_length(self):
+         return self.config.model_max_length
+
+     @property
+     def mask_token(self):
+         return self._tokenizer.mask_token
+
+     @property
+     def mask_token_id(self):
+         return self._tokenizer.mask_token_id
+
+     @property
+     def pad_token_id(self):
+         return self._tokenizer.pad_token_id
+
+     @property
+     def cls_token_id(self):
+         return self._tokenizer.cls_token_id
+
+     @property
+     def sep_token_id(self):
+         return self._tokenizer.sep_token_id
+
+     @property
+     def vocab(self):
+         return self._tokenizer.vocab
+
+     def __len__(self):
+         """
+         Size of the full vocabulary with the added tokens.
+         """
+         return len(self._tokenizer)
+
+     def pad(self, *args, **kwargs):
+         return self._tokenizer.pad(*args, **kwargs)
+
+     def convert_tokens_to_ids(self, *args, **kwargs):
+         return self._tokenizer.convert_tokens_to_ids(*args, **kwargs)
+
+     def batch_decode(self, *args, **kwargs):
+         return self._tokenizer.batch_decode(*args, **kwargs)
+
+     def decode(self, *args, **kwargs):
+         return self._tokenizer.decode(*args, **kwargs)
+
+     def tokenize(self, text, **kwargs):
+         return self._tokenizer.tokenize(text, **kwargs)
+
+     def encode(self, text, **kwargs):
+         input_ids = self._tokenizer.encode_plus(text, add_special_tokens=False, **kwargs)['input_ids']
+         input_ids = self.chunks(input_ids[: self.model_max_length - self.config.max_sentences],
+                                 chunk_size=self.config.max_sentence_length, special_id=self.type2id['input_ids'])
+         return input_ids
+
+     def get_special_tokens_mask(self, *args, **kwargs):
+         return self._tokenizer.get_special_tokens_mask(*args, **kwargs)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+         try:
+             tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+         except Exception:
+             tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+         return cls(tokenizer=tokenizer)
+
+     def save_pretrained(self, *args, **kwargs):
+         return self._tokenizer.save_pretrained(*args, **kwargs)
+
+     def __call__(self, text, **kwargs):
+         greedy_chunking = kwargs.pop('greedy_chunking', None)
+         text_pair = kwargs.pop('text_pair', None)
+         if isinstance(text[0], list):
+             batch = self.auto_chunking(text, **kwargs)
+         elif greedy_chunking:
+             # fixed uniform chunking
+             batch = self.uniform_chunking(text, **kwargs)
+         else:
+             # dynamic sentence splitting and grouping
+             batch = self.sentence_splitting(text, **kwargs)
+
+         if text_pair:
+             batch_b = self._tokenizer(text_pair, add_special_tokens=False,
+                                       padding=False, truncation=False)
+             for idx, sample in enumerate(batch['input_ids']):
+                 n_sentences = sum(sample[::self.config.max_sentence_size])
+                 for input_key in batch:
+                     batch[input_key][idx][self.config.max_sentence_size * n_sentences:
+                                           self.config.max_sentence_size * (n_sentences + 1)] = \
+                         self.pad_sentence(batch_b[input_key][idx],
+                                           special_id=(self.sep_token_id, self.pad_token_id)
+                                           if input_key == 'input_ids' else self.type2id[input_key])
+
+         return batch
+
+     def uniform_chunking(self, texts, **kwargs):
+         original_batch = self._tokenizer(texts, add_special_tokens=False, **kwargs)
+         batch = {input_type: [] for input_type in original_batch}
+         for input_type in original_batch:
+             fixed_batch = []
+             for example in original_batch[input_type]:
+                 fixed_batch.append(self.chunks(example[: self.model_max_length - self.config.max_sentences],
+                                                chunk_size=self.config.max_sentence_length,
+                                                special_id=self.type2id[input_type]))
+             batch[input_type] = fixed_batch if isinstance(fixed_batch[0], list) else torch.stack(fixed_batch)
+
+         if kwargs['padding']:
+             batch = self.pad(batch,
+                              padding=kwargs['padding'],
+                              max_length=kwargs['max_length'],
+                              pad_to_multiple_of=kwargs['max_length'])
+
+         return batch
+
+     def auto_chunking(self, texts, **kwargs):
+         batch = {}
+         for text_idx, text in enumerate(texts):
+             example_batch = self._tokenizer(text, add_special_tokens=False, **kwargs)
+             for input_key in example_batch:
+                 key_inputs_list = []
+                 for idx, example in enumerate(example_batch[input_key][:self.config.max_sentences]):
+                     key_inputs_list.append(self.pad_sentence(example, special_id=self.type2id[input_key]))
+                 if isinstance(key_inputs_list[0], list):
+                     key_inputs_list = [token for sentence in key_inputs_list for token in sentence]
+                 else:
+                     key_inputs_list = torch.stack(key_inputs_list)
+                 if input_key in batch:
+                     batch[input_key].append(key_inputs_list)
+                 else:
+                     batch[input_key] = [key_inputs_list]
+
+         if kwargs['padding']:
+             batch = self.pad(batch,
+                              padding=kwargs['padding'],
+                              max_length=kwargs['max_length'],
+                              pad_to_multiple_of=kwargs['max_length'])
+
+         return batch
+
+     def chunks(self, flat_inputs, chunk_size=128, special_id=0):
+         if isinstance(flat_inputs, list):
+             return self.list_chunks(flat_inputs, chunk_size, special_id)
+         else:
+             return self.tensor_chunks(flat_inputs, chunk_size, special_id)
+
+     def list_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
+         """Split a flat list into successive chunks of `chunk_size`, prepending a segment-level special id to each."""
+         structured_inputs = [[special_id[0] if sum(flat_inputs[i:i + chunk_size-1]) else special_id[1]]
+                              + flat_inputs[i:i + chunk_size-1] for i in range(0, len(flat_inputs), chunk_size-1)]
+         return [token_input for sentence_inputs in structured_inputs for token_input in sentence_inputs]
+
+     def tensor_chunks(self, flat_inputs, chunk_size=128, special_id=(0, 0)):
+         """Split a flat tensor into successive chunks of `chunk_size`, prepending a segment-level special id to each."""
+         structured_inputs = torch.stack([torch.cat((torch.tensor([special_id[0] if flat_inputs[i:i + chunk_size-1].sum() else special_id[1]], dtype=torch.int),
+                                                     flat_inputs[i:i + chunk_size-1])) for i in range(0, len(flat_inputs), chunk_size-1)])
+         return structured_inputs.reshape(-1)
+
+     def sentence_splitting(self, texts, **kwargs):
+         fixed_batch = []
+         doc_out = {}
+         for text in texts:
+             # sentence splitting
+             sentences = sent_tokenize(text)
+             # tokenization of sentences
+             sentences = self._tokenizer(sentences, add_special_tokens=False, padding=False, truncation=False)
+             # sentence grouping - merging short sentences to minimize padding
+             doc_out = self.sentence_grouping(sentences)
+             fixed_batch.append(doc_out)
+         # batchify examples
+         batch = {input_type: [] for input_type in doc_out}
+         for input_type in batch:
+             batch[input_type] = [example[input_type] for example in fixed_batch]
+             if not isinstance(batch[input_type][0], list):
+                 batch[input_type] = torch.stack(batch[input_type])
+
+         if kwargs['padding']:
+             batch = self.pad(batch,
+                              padding=kwargs['padding'],
+                              max_length=kwargs['max_length'],
+                              pad_to_multiple_of=kwargs['max_length'])
+
+         return batch
+
+     def sentence_grouping(self, sentences):
+         doc_out = {input_type: [] for input_type in sentences}
+         for input_type in sentences:
+             tmp_doc = []
+             tmp_sentence = []
+             for example in sentences[input_type]:
+                 if len(tmp_doc) >= self.config.max_sentences:
+                     break
+                 if len(tmp_sentence) + len(example) <= self.config.max_sentence_length - 1:
+                     tmp_sentence.extend(example)
+                 else:
+                     tmp_doc.append(self.pad_sentence(tmp_sentence if len(tmp_sentence) else example,
+                                                      chunk_size=self.config.max_sentence_length,
+                                                      special_id=self.type2id[input_type]))
+                     tmp_sentence = example if len(tmp_sentence) else example[self.config.max_sentence_length:]
+             if len(tmp_sentence) and len(tmp_doc) < self.config.max_sentences:
+                 tmp_doc.append(self.pad_sentence(tmp_sentence,
+                                                  chunk_size=self.config.max_sentence_length,
+                                                  special_id=self.type2id[input_type]))
+             doc_out[input_type] = [token for sentence in tmp_doc for token in sentence]
+         return doc_out
+
+     def pad_sentence(self, flat_input, chunk_size=128, special_id=(0, 0)):
+         if isinstance(flat_input, list):
+             return [special_id[0]] + flat_input[:chunk_size-1] + [self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1)
+         else:
+             return torch.cat((torch.tensor([special_id[0] if flat_input[:chunk_size-1].sum()
+                                             else special_id[1]], dtype=torch.int),
+                               flat_input[:chunk_size-1],
+                               torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
+                               ))
+
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
vocab.json ADDED
The diff for this file is too large to render. See raw diff