TibetanAI committed on
Commit
080f8e6
1 Parent(s): 6583286

Upload 36 files

Files changed (36)
  1. model/__init__.py +1 -0
  2. model/__pycache__/__init__.cpython-310.pyc +0 -0
  3. model/__pycache__/__init__.cpython-37.pyc +0 -0
  4. model/__pycache__/__init__.cpython-38.pyc +0 -0
  5. model/__pycache__/configuration_albert.cpython-37.pyc +0 -0
  6. model/__pycache__/configuration_albert.cpython-38.pyc +0 -0
  7. model/__pycache__/configuration_utils.cpython-37.pyc +0 -0
  8. model/__pycache__/configuration_utils.cpython-38.pyc +0 -0
  9. model/__pycache__/file_utils.cpython-37.pyc +0 -0
  10. model/__pycache__/file_utils.cpython-38.pyc +0 -0
  11. model/__pycache__/modeling_albert.cpython-37.pyc +0 -0
  12. model/__pycache__/modeling_albert.cpython-38.pyc +0 -0
  13. model/__pycache__/modeling_utils.cpython-37.pyc +0 -0
  14. model/__pycache__/modeling_utils.cpython-38.pyc +0 -0
  15. model/__pycache__/tokenization_albert.cpython-310.pyc +0 -0
  16. model/__pycache__/tokenization_albert.cpython-37.pyc +0 -0
  17. model/__pycache__/tokenization_albert.cpython-38.pyc +0 -0
  18. model/__pycache__/tokenization_bert.cpython-37.pyc +0 -0
  19. model/__pycache__/tokenization_utils.cpython-37.pyc +0 -0
  20. model/configuration_albert.py +79 -0
  21. model/configuration_bert.py +83 -0
  22. model/configuration_utils.py +206 -0
  23. model/file_utils.py +294 -0
  24. model/modeling_albert.py +1088 -0
  25. model/modeling_albert_bright.py +1002 -0
  26. model/modeling_bert.py +1149 -0
  27. model/modeling_utils.py +756 -0
  28. model/tokenization_albert.py +358 -0
  29. model/tokenization_bert.py +441 -0
  30. model/tokenization_utils.py +1065 -0
  31. test.py +16 -0
  32. tibetan-albert-syllable-base/config.json +31 -0
  33. tibetan-albert-syllable-base/optimizer.bin +3 -0
  34. tibetan-albert-syllable-base/pytorch_model.bin +3 -0
  35. tibetan-albert-syllable-base/training_args.bin +3 -0
  36. tibetan-albert-syllable-base/vocab.txt +0 -0
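Before the per-file diffs, a minimal sketch of how the uploaded pieces fit together: the model/ package vendors the ALBERT implementation, and tibetan-albert-syllable-base/ holds config.json, pytorch_model.bin and vocab.txt. The commit also adds a test.py that presumably demonstrates usage, but it is not shown in this view, so the snippet below is an assumption-laden illustration using only classes visible in the diff; the AlbertModel constructor/forward signature and the checkpoint key names are assumed to follow the usual BERT-style conventions.

import torch
from model.configuration_albert import AlbertConfig
from model.modeling_albert import AlbertModel

# Read the shipped configuration and build the architecture from it.
config = AlbertConfig.from_pretrained("tibetan-albert-syllable-base/")
model = AlbertModel(config)

# Load the uploaded weights directly; strict=False because the checkpoint keys
# may carry a "bert." prefix (base_model_prefix in modeling_albert.py).
state_dict = torch.load("tibetan-albert-syllable-base/pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model.eval()

# Dummy token ids just to exercise the forward pass; real ids would come from
# model/tokenization_albert.py, which is not shown in this excerpt.
input_ids = torch.randint(0, config.vocab_size, (1, 16), dtype=torch.long)
with torch.no_grad():
    last_hidden_state = model(input_ids)[0]   # (1, 16, hidden_size) per the AlbertModel docstring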
model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ #encoding:utf-8
model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (146 Bytes).
model/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (140 Bytes).
model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (143 Bytes).
model/__pycache__/configuration_albert.cpython-37.pyc ADDED
Binary file (3.45 kB).
model/__pycache__/configuration_albert.cpython-38.pyc ADDED
Binary file (3.46 kB).
model/__pycache__/configuration_utils.cpython-37.pyc ADDED
Binary file (9.34 kB).
model/__pycache__/configuration_utils.cpython-38.pyc ADDED
Binary file (9.39 kB).
model/__pycache__/file_utils.cpython-37.pyc ADDED
Binary file (8.49 kB).
model/__pycache__/file_utils.cpython-38.pyc ADDED
Binary file (8.41 kB).
model/__pycache__/modeling_albert.cpython-37.pyc ADDED
Binary file (49.9 kB).
model/__pycache__/modeling_albert.cpython-38.pyc ADDED
Binary file (49.1 kB).
model/__pycache__/modeling_utils.cpython-37.pyc ADDED
Binary file (31.6 kB).
model/__pycache__/modeling_utils.cpython-38.pyc ADDED
Binary file (31.5 kB).
model/__pycache__/tokenization_albert.cpython-310.pyc ADDED
Binary file (10.2 kB).
model/__pycache__/tokenization_albert.cpython-37.pyc ADDED
Binary file (10.1 kB).
model/__pycache__/tokenization_albert.cpython-38.pyc ADDED
Binary file (10.2 kB).
model/__pycache__/tokenization_bert.cpython-37.pyc ADDED
Binary file (15.2 kB).
model/__pycache__/tokenization_utils.cpython-37.pyc ADDED
Binary file (45 kB).
model/configuration_albert.py ADDED
@@ -0,0 +1,79 @@
1
+ """ BERT model configuration """
2
+ from __future__ import absolute_import, division, print_function, unicode_literals
3
+
4
+ import json
5
+ import logging
6
+ import sys
7
+ from io import open
8
+
9
+ from .configuration_utils import PretrainedConfig
10
+ logger = logging.getLogger(__name__)
11
+
12
+ class AlbertConfig(PretrainedConfig):
13
+ r"""
14
+ Arguments:
15
+ vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `AlbertModel`.
16
+ hidden_size: Size of the encoder layers and the pooler layer.
17
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
18
+ num_attention_heads: Number of attention heads for each attention layer in
19
+ the Transformer encoder.
20
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
21
+ layer in the Transformer encoder.
22
+ hidden_act: The non-linear activation function (function or string) in the
23
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
24
+ hidden_dropout_prob: The dropout probability for all fully connected
25
+ layers in the embeddings, encoder, and pooler.
26
+ attention_probs_dropout_prob: The dropout ratio for the attention
27
+ probabilities.
28
+ max_position_embeddings: The maximum sequence length that this model might
29
+ ever be used with. Typically set this to something large just in case
30
+ (e.g., 512 or 1024 or 2048).
31
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
32
+ `BertModel`.
33
+ initializer_range: The stddev of the truncated_normal_initializer for
34
+ initializing all weight matrices.
35
+ layer_norm_eps: The epsilon used by LayerNorm.
36
+ """
37
+ def __init__(self,
38
+ vocab_size_or_config_json_file=30000,
39
+ embedding_size=128,
40
+ hidden_size=4096,
41
+ num_hidden_layers=12,
42
+ num_hidden_groups=1,
43
+ num_attention_heads=64,
44
+ intermediate_size=16384,
45
+ inner_group_num=1,
46
+ hidden_act="gelu_new",
47
+ hidden_dropout_prob=0,
48
+ attention_probs_dropout_prob=0,
49
+ max_position_embeddings=512,
50
+ type_vocab_size=2,
51
+ initializer_range=0.02,
52
+ layer_norm_eps=1e-12,
53
+ **kwargs):
54
+ super(AlbertConfig, self).__init__(**kwargs)
55
+ if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
56
+ and isinstance(vocab_size_or_config_json_file, unicode)):
57
+ with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
58
+ json_config = json.loads(reader.read())
59
+ for key, value in json_config.items():
60
+ self.__dict__[key] = value
61
+ elif isinstance(vocab_size_or_config_json_file, int):
62
+ self.vocab_size = vocab_size_or_config_json_file
63
+ self.hidden_size = hidden_size
64
+ self.num_hidden_layers = num_hidden_layers
65
+ self.num_attention_heads = num_attention_heads
66
+ self.hidden_act = hidden_act
67
+ self.intermediate_size = intermediate_size
68
+ self.hidden_dropout_prob = hidden_dropout_prob
69
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
70
+ self.max_position_embeddings = max_position_embeddings
71
+ self.type_vocab_size = type_vocab_size
72
+ self.initializer_range = initializer_range
73
+ self.layer_norm_eps = layer_norm_eps
74
+ self.embedding_size = embedding_size
75
+ self.inner_group_num = inner_group_num
76
+ self.num_hidden_groups = num_hidden_groups
77
+ else:
78
+ raise ValueError("First argument must be either a vocabulary size (int)"
79
+ " or the path to a pretrained model config file (str)")
model/configuration_bert.py ADDED
@@ -0,0 +1,83 @@
1
+
2
+ """ BERT model configuration """
3
+
4
+ from __future__ import absolute_import, division, print_function, unicode_literals
5
+
6
+ import json
7
+ import logging
8
+ import sys
9
+ from io import open
10
+
11
+ from .configuration_utils import PretrainedConfig
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
16
+ class BertConfig(PretrainedConfig):
17
+ r"""
18
+ :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a
19
+ `BertModel`.
20
+
21
+
22
+ Arguments:
23
+ vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`.
24
+ hidden_size: Size of the encoder layers and the pooler layer.
25
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
26
+ num_attention_heads: Number of attention heads for each attention layer in
27
+ the Transformer encoder.
28
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
29
+ layer in the Transformer encoder.
30
+ hidden_act: The non-linear activation function (function or string) in the
31
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
32
+ hidden_dropout_prob: The dropout probability for all fully connected
33
+ layers in the embeddings, encoder, and pooler.
34
+ attention_probs_dropout_prob: The dropout ratio for the attention
35
+ probabilities.
36
+ max_position_embeddings: The maximum sequence length that this model might
37
+ ever be used with. Typically set this to something large just in case
38
+ (e.g., 512 or 1024 or 2048).
39
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
40
+ `BertModel`.
41
+ initializer_range: The stddev of the truncated_normal_initializer for
42
+ initializing all weight matrices.
43
+ layer_norm_eps: The epsilon used by LayerNorm.
44
+ """
45
+ pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
46
+
47
+ def __init__(self,
48
+ vocab_size_or_config_json_file=30522,
49
+ hidden_size=768,
50
+ num_hidden_layers=12,
51
+ num_attention_heads=12,
52
+ intermediate_size=3072,
53
+ hidden_act="gelu",
54
+ hidden_dropout_prob=0.1,
55
+ attention_probs_dropout_prob=0.1,
56
+ max_position_embeddings=512,
57
+ type_vocab_size=2,
58
+ initializer_range=0.02,
59
+ layer_norm_eps=1e-12,
60
+ **kwargs):
61
+ super(BertConfig, self).__init__(**kwargs)
62
+ if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
63
+ and isinstance(vocab_size_or_config_json_file, unicode)):
64
+ with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
65
+ json_config = json.loads(reader.read())
66
+ for key, value in json_config.items():
67
+ self.__dict__[key] = value
68
+ elif isinstance(vocab_size_or_config_json_file, int):
69
+ self.vocab_size = vocab_size_or_config_json_file
70
+ self.hidden_size = hidden_size
71
+ self.num_hidden_layers = num_hidden_layers
72
+ self.num_attention_heads = num_attention_heads
73
+ self.hidden_act = hidden_act
74
+ self.intermediate_size = intermediate_size
75
+ self.hidden_dropout_prob = hidden_dropout_prob
76
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
77
+ self.max_position_embeddings = max_position_embeddings
78
+ self.type_vocab_size = type_vocab_size
79
+ self.initializer_range = initializer_range
80
+ self.layer_norm_eps = layer_norm_eps
81
+ else:
82
+ raise ValueError("First argument must be either a vocabulary size (int)"
83
+ " or the path to a pretrained model config file (str)")
model/configuration_utils.py ADDED
@@ -0,0 +1,206 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Configuration base class and utilities."""
17
+
18
+ from __future__ import (absolute_import, division, print_function,
19
+ unicode_literals)
20
+
21
+ import copy
22
+ import json
23
+ import logging
24
+ import os
25
+ from io import open
26
+ from model.file_utils import cached_path, CONFIG_NAME
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+ class PretrainedConfig(object):
31
+ r""" Base class for all configuration classes.
32
+ Handles a few parameters common to all models' configurations, as well as methods for loading/downloading/saving configurations.
33
+
34
+ Note:
35
+ A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
36
+ It only affects the model's configuration.
37
+
38
+ Class attributes (overridden by derived classes):
39
+ - ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
40
+
41
+ Parameters:
42
+ ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
43
+ ``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens)
44
+ ``output_attentions``: boolean, default `False`. Whether the model should return attention weights.
45
+ ``output_hidden_states``: boolean, default `False`. Whether the model should return all hidden states.
46
+ ``torchscript``: boolean, default `False`. Whether the model is used with TorchScript.
47
+ """
48
+ pretrained_config_archive_map = {}
49
+
50
+ def __init__(self, **kwargs):
51
+ self.finetuning_task = kwargs.pop('finetuning_task', None)
52
+ self.num_labels = kwargs.pop('num_labels', 2)
53
+ self.output_attentions = kwargs.pop('output_attentions', False)
54
+ self.output_hidden_states = kwargs.pop('output_hidden_states', False)
55
+ self.torchscript = kwargs.pop('torchscript', False)
56
+ self.pruned_heads = kwargs.pop('pruned_heads', {})
57
+
58
+ def save_pretrained(self, save_directory):
59
+ """ Save a configuration object to the directory `save_directory`, so that it
60
+ can be re-loaded using the :func:`~pytorch_transformers.PretrainedConfig.from_pretrained` class method.
61
+ """
62
+ assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
63
+
64
+ # If we save using the predefined names, we can load using `from_pretrained`
65
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
66
+
67
+ self.to_json_file(output_config_file)
68
+
69
+ @classmethod
70
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
71
+ r""" Instantiate a :class:`~pytorch_transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
72
+
73
+ Parameters:
74
+ pretrained_model_name_or_path: either:
75
+
76
+ - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``.
77
+ - a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
78
+ - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``.
79
+
80
+ cache_dir: (`optional`) string:
81
+ Path to a directory in which a downloaded pre-trained model
82
+ configuration should be cached if the standard cache should not be used.
83
+
84
+ kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading.
85
+
86
+ - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values.
87
+ - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter.
88
+
89
+ force_download: (`optional`) boolean, default False:
90
+ Force (re-)downloading the model weights and configuration files, overriding the cached versions if they exist.
91
+
92
+ proxies: (`optional`) dict, default None:
93
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
94
+ The proxies are used on each request.
95
+
96
+ return_unused_kwargs: (`optional`) bool:
97
+
98
+ - If False, then this function returns just the final configuration object.
99
+ - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored.
100
+
101
+ Examples::
102
+
103
+ # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
104
+ # derived class: BertConfig
105
+ config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
106
+ config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
107
+ config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
108
+ config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
109
+ assert config.output_attentions == True
110
+ config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True,
111
+ foo=False, return_unused_kwargs=True)
112
+ assert config.output_attentions == True
113
+ assert unused_kwargs == {'foo': False}
114
+
115
+ """
116
+ cache_dir = kwargs.pop('cache_dir', None)
117
+ force_download = kwargs.pop('force_download', False)
118
+ proxies = kwargs.pop('proxies', None)
119
+ return_unused_kwargs = kwargs.pop('return_unused_kwargs', False)
120
+
121
+ if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
122
+ config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
123
+ elif os.path.isdir(pretrained_model_name_or_path):
124
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
125
+ else:
126
+ config_file = pretrained_model_name_or_path
127
+ # redirect to the cache, if necessary
128
+ try:
129
+ resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
130
+ except EnvironmentError as e:
131
+ if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
132
+ logger.error(
133
+ "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
134
+ config_file))
135
+ else:
136
+ logger.error(
137
+ "Model name '{}' was not found in model name list ({}). "
138
+ "We assumed '{}' was a path or url but couldn't find any file "
139
+ "associated to this path or url.".format(
140
+ pretrained_model_name_or_path,
141
+ ', '.join(cls.pretrained_config_archive_map.keys()),
142
+ config_file))
143
+ raise e
144
+ if resolved_config_file == config_file:
145
+ logger.info("loading configuration file {}".format(config_file))
146
+ else:
147
+ logger.info("loading configuration file {} from cache at {}".format(
148
+ config_file, resolved_config_file))
149
+
150
+ # Load config
151
+ config = cls.from_json_file(resolved_config_file)
152
+
153
+ if hasattr(config, 'pruned_heads'):
154
+ config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items())
155
+
156
+ # Update config with kwargs if needed
157
+ to_remove = []
158
+ for key, value in kwargs.items():
159
+ if hasattr(config, key):
160
+ setattr(config, key, value)
161
+ to_remove.append(key)
162
+ else:
163
+ setattr(config, key, value)
164
+ for key in to_remove:
165
+ kwargs.pop(key, None)
166
+
167
+ logger.info("Model config %s", config)
168
+ if return_unused_kwargs:
169
+ return config, kwargs
170
+ else:
171
+ return config
172
+
173
+ @classmethod
174
+ def from_dict(cls, json_object):
175
+ """Constructs a `Config` from a Python dictionary of parameters."""
176
+ config = cls(vocab_size_or_config_json_file=-1)
177
+ for key, value in json_object.items():
178
+ config.__dict__[key] = value
179
+ return config
180
+
181
+ @classmethod
182
+ def from_json_file(cls, json_file):
183
+ """Constructs a `BertConfig` from a json file of parameters."""
184
+ with open(json_file, "r", encoding='utf-8') as reader:
185
+ text = reader.read()
186
+ return cls.from_dict(json.loads(text))
187
+
188
+ def __eq__(self, other):
189
+ return self.__dict__ == other.__dict__
190
+
191
+ def __repr__(self):
192
+ return str(self.to_json_string())
193
+
194
+ def to_dict(self):
195
+ """Serializes this instance to a Python dictionary."""
196
+ output = copy.deepcopy(self.__dict__)
197
+ return output
198
+
199
+ def to_json_string(self):
200
+ """Serializes this instance to a JSON string."""
201
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
202
+
203
+ def to_json_file(self, json_file_path):
204
+ """ Save this instance to a json file."""
205
+ with open(json_file_path, "w", encoding='utf-8') as writer:
206
+ writer.write(self.to_json_string())
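A brief sketch of the load / override / save cycle that PretrainedConfig provides, exercised through the AlbertConfig subclass defined earlier in this commit; the output directory name is arbitrary:

import os
from model.configuration_albert import AlbertConfig

# from_pretrained accepts a directory (it looks for config.json inside) plus
# keyword overrides that are applied to matching attributes after loading.
config = AlbertConfig.from_pretrained("tibetan-albert-syllable-base/",
                                      output_hidden_states=True)
assert config.output_hidden_states is True

# save_pretrained requires an existing directory and writes config.json there,
# so the same path can later be handed back to from_pretrained.
os.makedirs("exported-config", exist_ok=True)
config.save_pretrained("exported-config")
reloaded = AlbertConfig.from_pretrained("exported-config")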
model/file_utils.py ADDED
@@ -0,0 +1,294 @@
1
+ """
2
+ Utilities for working with the local dataset cache.
3
+ This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4
+ Copyright by the AllenNLP authors.
5
+ """
6
+ from __future__ import (absolute_import, division, print_function, unicode_literals)
7
+
8
+ import sys
9
+ import json
10
+ import logging
11
+ import os
12
+ import six
13
+ import shutil
14
+ import tempfile
15
+ import fnmatch
16
+ from functools import wraps
17
+ from hashlib import sha256
18
+ from io import open
19
+
20
+ import boto3
21
+ from botocore.config import Config
22
+ from botocore.exceptions import ClientError
23
+ import requests
24
+ from tqdm import tqdm
25
+
26
+ try:
27
+ from torch.hub import _get_torch_home
28
+ torch_cache_home = _get_torch_home()
29
+ except ImportError:
30
+ torch_cache_home = os.path.expanduser(
31
+ os.getenv('TORCH_HOME', os.path.join(
32
+ os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
33
+ default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers')
34
+
35
+ try:
36
+ from urllib.parse import urlparse
37
+ except ImportError:
38
+ from urlparse import urlparse
39
+
40
+ try:
41
+ from pathlib import Path
42
+ PYTORCH_PRETRAINED_BERT_CACHE = Path(
43
+ os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path)))
44
+ except (AttributeError, ImportError):
45
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE',
46
+ os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
47
+ default_cache_path))
48
+
49
+ PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility
50
+
51
+ WEIGHTS_NAME = "pytorch_model.bin"
52
+ TF_WEIGHTS_NAME = 'model.ckpt'
53
+ CONFIG_NAME = "config.json"
54
+
55
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
56
+
57
+ if not six.PY2:
58
+ def add_start_docstrings(*docstr):
59
+ def docstring_decorator(fn):
60
+ fn.__doc__ = ''.join(docstr) + fn.__doc__
61
+ return fn
62
+ return docstring_decorator
63
+
64
+ def add_end_docstrings(*docstr):
65
+ def docstring_decorator(fn):
66
+ fn.__doc__ = fn.__doc__ + ''.join(docstr)
67
+ return fn
68
+ return docstring_decorator
69
+ else:
70
+ # Not possible to update class docstrings on python2
71
+ def add_start_docstrings(*docstr):
72
+ def docstring_decorator(fn):
73
+ return fn
74
+ return docstring_decorator
75
+
76
+ def add_end_docstrings(*docstr):
77
+ def docstring_decorator(fn):
78
+ return fn
79
+ return docstring_decorator
80
+
81
+ def url_to_filename(url, etag=None):
82
+ """
83
+ Convert `url` into a hashed filename in a repeatable way.
84
+ If `etag` is specified, append its hash to the url's, delimited
85
+ by a period.
86
+ """
87
+ url_bytes = url.encode('utf-8')
88
+ url_hash = sha256(url_bytes)
89
+ filename = url_hash.hexdigest()
90
+
91
+ if etag:
92
+ etag_bytes = etag.encode('utf-8')
93
+ etag_hash = sha256(etag_bytes)
94
+ filename += '.' + etag_hash.hexdigest()
95
+
96
+ return filename
97
+
98
+
99
+ def filename_to_url(filename, cache_dir=None):
100
+ """
101
+ Return the url and etag (which may be ``None``) stored for `filename`.
102
+ Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
103
+ """
104
+ if cache_dir is None:
105
+ cache_dir = PYTORCH_TRANSFORMERS_CACHE
106
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
107
+ cache_dir = str(cache_dir)
108
+
109
+ cache_path = os.path.join(cache_dir, filename)
110
+ if not os.path.exists(cache_path):
111
+ raise EnvironmentError("file {} not found".format(cache_path))
112
+
113
+ meta_path = cache_path + '.json'
114
+ if not os.path.exists(meta_path):
115
+ raise EnvironmentError("file {} not found".format(meta_path))
116
+
117
+ with open(meta_path, encoding="utf-8") as meta_file:
118
+ metadata = json.load(meta_file)
119
+ url = metadata['url']
120
+ etag = metadata['etag']
121
+
122
+ return url, etag
123
+
124
+
125
+ def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
126
+ """
127
+ Given something that might be a URL (or might be a local path),
128
+ determine which. If it's a URL, download the file and cache it, and
129
+ return the path to the cached file. If it's already a local path,
130
+ make sure the file exists and then return the path.
131
+ Args:
132
+ cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
133
+ force_download: if True, re-download the file even if it's already cached in the cache dir.
134
+ """
135
+ if cache_dir is None:
136
+ cache_dir = PYTORCH_TRANSFORMERS_CACHE
137
+ if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
138
+ url_or_filename = str(url_or_filename)
139
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
140
+ cache_dir = str(cache_dir)
141
+
142
+ parsed = urlparse(url_or_filename)
143
+
144
+ if parsed.scheme in ('http', 'https', 's3'):
145
+ # URL, so get it from the cache (downloading if necessary)
146
+ return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
147
+ elif os.path.exists(url_or_filename):
148
+ # File, and it exists.
149
+ return url_or_filename
150
+ elif parsed.scheme == '':
151
+ # File, but it doesn't exist.
152
+ raise EnvironmentError("file {} not found".format(url_or_filename))
153
+ else:
154
+ # Something unknown
155
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
156
+
157
+
158
+ def split_s3_path(url):
159
+ """Split a full s3 path into the bucket name and path."""
160
+ parsed = urlparse(url)
161
+ if not parsed.netloc or not parsed.path:
162
+ raise ValueError("bad s3 path {}".format(url))
163
+ bucket_name = parsed.netloc
164
+ s3_path = parsed.path
165
+ # Remove '/' at beginning of path.
166
+ if s3_path.startswith("/"):
167
+ s3_path = s3_path[1:]
168
+ return bucket_name, s3_path
169
+
170
+
171
+ def s3_request(func):
172
+ """
173
+ Wrapper function for s3 requests in order to create more helpful error
174
+ messages.
175
+ """
176
+
177
+ @wraps(func)
178
+ def wrapper(url, *args, **kwargs):
179
+ try:
180
+ return func(url, *args, **kwargs)
181
+ except ClientError as exc:
182
+ if int(exc.response["Error"]["Code"]) == 404:
183
+ raise EnvironmentError("file {} not found".format(url))
184
+ else:
185
+ raise
186
+
187
+ return wrapper
188
+
189
+
190
+ @s3_request
191
+ def s3_etag(url, proxies=None):
192
+ """Check ETag on S3 object."""
193
+ s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
194
+ bucket_name, s3_path = split_s3_path(url)
195
+ s3_object = s3_resource.Object(bucket_name, s3_path)
196
+ return s3_object.e_tag
197
+
198
+
199
+ @s3_request
200
+ def s3_get(url, temp_file, proxies=None):
201
+ """Pull a file directly from S3."""
202
+ s3_resource = boto3.resource("s3", config=Config(proxies=proxies))
203
+ bucket_name, s3_path = split_s3_path(url)
204
+ s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
205
+
206
+
207
+ def http_get(url, temp_file, proxies=None):
208
+ req = requests.get(url, stream=True, proxies=proxies)
209
+ content_length = req.headers.get('Content-Length')
210
+ total = int(content_length) if content_length is not None else None
211
+ progress = tqdm(unit="B", total=total)
212
+ for chunk in req.iter_content(chunk_size=1024):
213
+ if chunk: # filter out keep-alive new chunks
214
+ progress.update(len(chunk))
215
+ temp_file.write(chunk)
216
+ progress.close()
217
+
218
+
219
+ def get_from_cache(url, cache_dir=None, force_download=False, proxies=None):
220
+ """
221
+ Given a URL, look for the corresponding dataset in the local cache.
222
+ If it's not there, download it. Then return the path to the cached file.
223
+ """
224
+ if cache_dir is None:
225
+ cache_dir = PYTORCH_TRANSFORMERS_CACHE
226
+ if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
227
+ cache_dir = str(cache_dir)
228
+ if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
229
+ cache_dir = str(cache_dir)
230
+
231
+ if not os.path.exists(cache_dir):
232
+ os.makedirs(cache_dir)
233
+
234
+ # Get eTag to add to filename, if it exists.
235
+ if url.startswith("s3://"):
236
+ etag = s3_etag(url, proxies=proxies)
237
+ else:
238
+ try:
239
+ response = requests.head(url, allow_redirects=True, proxies=proxies)
240
+ if response.status_code != 200:
241
+ etag = None
242
+ else:
243
+ etag = response.headers.get("ETag")
244
+ except EnvironmentError:
245
+ etag = None
246
+
247
+ if sys.version_info[0] == 2 and etag is not None:
248
+ etag = etag.decode('utf-8')
249
+ filename = url_to_filename(url, etag)
250
+
251
+ # get cache path to put the file
252
+ cache_path = os.path.join(cache_dir, filename)
253
+
254
+ # If we don't have a connection (etag is None) and can't identify the file
255
+ # try to get the last downloaded one
256
+ if not os.path.exists(cache_path) and etag is None:
257
+ matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
258
+ matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
259
+ if matching_files:
260
+ cache_path = os.path.join(cache_dir, matching_files[-1])
261
+
262
+ if not os.path.exists(cache_path) or force_download:
263
+ # Download to temporary file, then copy to cache dir once finished.
264
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
265
+ with tempfile.NamedTemporaryFile() as temp_file:
266
+ logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
267
+
268
+ # GET file object
269
+ if url.startswith("s3://"):
270
+ s3_get(url, temp_file, proxies=proxies)
271
+ else:
272
+ http_get(url, temp_file, proxies=proxies)
273
+
274
+ # we are copying the file before closing it, so flush to avoid truncation
275
+ temp_file.flush()
276
+ # shutil.copyfileobj() starts at the current position, so go to the start
277
+ temp_file.seek(0)
278
+
279
+ logger.info("copying %s to cache at %s", temp_file.name, cache_path)
280
+ with open(cache_path, 'wb') as cache_file:
281
+ shutil.copyfileobj(temp_file, cache_file)
282
+
283
+ logger.info("creating metadata file for %s", cache_path)
284
+ meta = {'url': url, 'etag': etag}
285
+ meta_path = cache_path + '.json'
286
+ with open(meta_path, 'w') as meta_file:
287
+ output_string = json.dumps(meta)
288
+ if sys.version_info[0] == 2 and isinstance(output_string, str):
289
+ output_string = unicode(output_string, 'utf-8') # The beauty of python 2
290
+ meta_file.write(output_string)
291
+
292
+ logger.info("removing temp file %s", temp_file.name)
293
+
294
+ return cache_path
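cached_path is the only helper in this file that the configuration code above relies on. A small sketch of its behaviour for local files, with the URL branch left commented out since it needs network access (the URL shown is illustrative only):

from model.file_utils import cached_path, url_to_filename, CONFIG_NAME

# Local path: existence is checked and the same path is returned unchanged.
local = cached_path("tibetan-albert-syllable-base/" + CONFIG_NAME)

# URL: downloaded once into PYTORCH_TRANSFORMERS_CACHE under a sha256-derived
# name, then served from the cache on later calls.
# remote = cached_path("https://example.com/some-config.json")

# The cache filename is deterministic: sha256(url), plus sha256(etag) if present.
print(url_to_filename("https://example.com/some-config.json", etag='"abc"'))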
model/modeling_albert.py ADDED
@@ -0,0 +1,1088 @@
1
+ """PyTorch ALBERT model. """
2
+ from __future__ import absolute_import, division, print_function, unicode_literals
3
+ import logging
4
+ import math
5
+ import os
6
+ import sys
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import CrossEntropyLoss, MSELoss
10
+ from .modeling_utils import PreTrainedModel, prune_linear_layer
11
+ from .configuration_albert import AlbertConfig
12
+ from .file_utils import add_start_docstrings
13
+ logger = logging.getLogger(__name__)
14
+
15
+ ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
16
+ 'albert-base': "",
17
+ 'albert-large': "",
18
+ 'albert-xlarge': "",
19
+ 'albert-xxlarge': "",
20
+ }
21
+ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
22
+ """ Load tf checkpoints in a pytorch model.
23
+ """
24
+ try:
25
+ import re
26
+ import numpy as np
27
+ import tensorflow as tf
28
+ except ImportError:
29
+ logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
30
+ "https://www.tensorflow.org/install/ for installation instructions.")
31
+ raise
32
+ tf_path = os.path.abspath(tf_checkpoint_path)
33
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
34
+ if not os.path.exists(tf_path+'/checkpoint'):
35
+ tf_path = tf_path + "/variables/variables"
36
+ # Load weights from TF model
37
+ init_vars = tf.train.list_variables(tf_path)
38
+ names = []
39
+ arrays = []
40
+ for name, shape in init_vars:
41
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
42
+ array = tf.train.load_variable(tf_path, name)
43
+ names.append(name)
44
+ arrays.append(array)
45
+ for name, array in zip(names, arrays):
46
+ name = name.replace("attention_1","attention")
47
+ name = name.replace("ffn_1","ffn")
48
+ name = name.split('/')
49
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
50
+ # which are not required for using pretrained model
51
+ if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
52
+ logger.info("Skipping {}".format("/".join(name)))
53
+ continue
54
+ pointer = model
55
+ for m_name in name:
56
+ if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
57
+ l = re.split(r'_(\d+)', m_name)
58
+ elif re.fullmatch(r'[A-Za-z]+_+[A-Za-z]+_\d+', m_name):
59
+ l = re.split(r'_(\d+)', m_name)
60
+ else:
61
+ l = [m_name]
62
+ if l[0] in ['LayerNorm', 'attention', 'ffn'] and len(l) >= 2:
63
+ l = ["_".join(l[:-1])]
64
+ if l[0] == 'kernel' or l[0] == 'gamma':
65
+ pointer = getattr(pointer, 'weight')
66
+ elif l[0] == 'output_bias' or l[0] == 'beta':
67
+ pointer = getattr(pointer, 'bias')
68
+ elif l[0] == 'output_weights':
69
+ pointer = getattr(pointer, 'weight')
70
+ elif l[0] == 'squad':
71
+ pointer = getattr(pointer, 'classifier')
72
+ else:
73
+ try:
74
+ pointer = getattr(pointer, l[0])
75
+ except AttributeError:
76
+ logger.info("Skipping {}".format("/".join(name)))
77
+ continue
78
+ if len(l) >= 2:
79
+ num = int(l[1])
80
+ pointer = pointer[num]
81
+
82
+ if m_name[-11:] == '_embeddings':
83
+ pointer = getattr(pointer, 'weight')
84
+ elif m_name == 'kernel':
85
+ array = np.transpose(array)
86
+ try:
87
+ assert pointer.shape == array.shape
88
+ except AssertionError as e:
89
+ e.args += (pointer.shape, array.shape)
90
+ raise
91
+ logger.info("Initialize PyTorch weight {}".format(name))
92
+ pointer.data = torch.from_numpy(array)
93
+ return model
94
+
95
+ def gelu(x):
96
+ """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
97
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
98
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
99
+ Also see https://arxiv.org/abs/1606.08415
100
+ """
101
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
102
+
103
+ def gelu_new(x):
104
+ """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
105
+ Also see https://arxiv.org/abs/1606.08415
106
+ """
107
+ return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
108
+
109
+ def swish(x):
110
+ return x * torch.sigmoid(x)
111
+
112
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new}
113
+ AlbertLayerNorm = torch.nn.LayerNorm
114
+
115
+ class AlbertEmbeddings(nn.Module):
116
+ """Construct the embeddings from word, position and token_type embeddings.
117
+ """
118
+ def __init__(self, config):
119
+ super(AlbertEmbeddings, self).__init__()
120
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=0)
121
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
122
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
123
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
124
+ self.LayerNorm = AlbertLayerNorm(config.embedding_size, eps=config.layer_norm_eps)
125
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
126
+
127
+ def forward(self, input_ids, token_type_ids=None, position_ids=None):
128
+ seq_length = input_ids.size(1)
129
+ if position_ids is None:
130
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
131
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
132
+ if token_type_ids is None:
133
+ token_type_ids = torch.zeros_like(input_ids)
134
+ words_embeddings = self.word_embeddings(input_ids)
135
+ position_embeddings = self.position_embeddings(position_ids)
136
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
137
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
138
+ embeddings = self.LayerNorm(embeddings)
139
+ embeddings = self.dropout(embeddings)
140
+ return embeddings
141
+
142
+ class AlbertSelfAttention(nn.Module):
143
+ def __init__(self, config):
144
+ super(AlbertSelfAttention, self).__init__()
145
+ if config.hidden_size % config.num_attention_heads != 0:
146
+ raise ValueError(
147
+ "The hidden size (%d) is not a multiple of the number of attention "
148
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
149
+ self.output_attentions = config.output_attentions
150
+ self.num_attention_heads = config.num_attention_heads
151
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
152
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
153
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
154
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
155
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
156
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
157
+
158
+ def transpose_for_scores(self, x):
159
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
160
+ x = x.view(*new_x_shape)
161
+ return x.permute(0, 2, 1, 3)
162
+
163
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
164
+ mixed_query_layer = self.query(hidden_states)
165
+ mixed_key_layer = self.key(hidden_states)
166
+ mixed_value_layer = self.value(hidden_states)
167
+
168
+ query_layer = self.transpose_for_scores(mixed_query_layer)
169
+ key_layer = self.transpose_for_scores(mixed_key_layer)
170
+ value_layer = self.transpose_for_scores(mixed_value_layer)
171
+
172
+ # Take the dot product between "query" and "key" to get the raw attention scores.
173
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
174
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
175
+ if attention_mask is not None:
176
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
177
+ attention_scores = attention_scores + attention_mask
178
+
179
+ # Normalize the attention scores to probabilities.
180
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
181
+
182
+ # This is actually dropping out entire tokens to attend to, which might
183
+ # seem a bit unusual, but is taken from the original Transformer paper.
184
+ attention_probs = self.dropout(attention_probs)
185
+
186
+ # Mask heads if we want to
187
+ if head_mask is not None:
188
+ attention_probs = attention_probs * head_mask
189
+
190
+ context_layer = torch.matmul(attention_probs, value_layer)
191
+
192
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
193
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
194
+ context_layer = context_layer.view(*new_context_layer_shape)
195
+ outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
196
+ return outputs
197
+
198
+ class AlbertSelfOutput(nn.Module):
199
+ def __init__(self, config):
200
+ super(AlbertSelfOutput, self).__init__()
201
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
202
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
203
+ def forward(self, hidden_states, input_tensor):
204
+ hidden_states = self.dense(hidden_states)
205
+ hidden_states = self.dropout(hidden_states)
206
+ return hidden_states
207
+
208
+ class AlbertAttention(nn.Module):
209
+ def __init__(self, config):
210
+ super(AlbertAttention, self).__init__()
211
+ self.self = AlbertSelfAttention(config)
212
+ self.output = AlbertSelfOutput(config)
213
+ self.pruned_heads = set()
214
+
215
+ def prune_heads(self, heads):
216
+ if len(heads) == 0:
217
+ return
218
+ mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
219
+ heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
220
+ for head in heads:
221
+ # Compute how many pruned heads are before the head and move the index accordingly
222
+ head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
223
+ mask[head] = 0
224
+ mask = mask.view(-1).contiguous().eq(1)
225
+ index = torch.arange(len(mask))[mask].long()
226
+
227
+ # Prune linear layers
228
+ self.self.query = prune_linear_layer(self.self.query, index)
229
+ self.self.key = prune_linear_layer(self.self.key, index)
230
+ self.self.value = prune_linear_layer(self.self.value, index)
231
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
232
+
233
+ # Update hyper params and store pruned heads
234
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
235
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
236
+ self.pruned_heads = self.pruned_heads.union(heads)
237
+
238
+ def forward(self, input_tensor, attention_mask=None, head_mask=None):
239
+ self_outputs = self.self(input_tensor, attention_mask, head_mask)
240
+ attention_output = self.output(self_outputs[0], input_tensor)
241
+ outputs = (attention_output,self_outputs)
242
+ return outputs
243
+
244
+ class AlbertOutput(nn.Module):
245
+ def __init__(self, config):
246
+ super(AlbertOutput, self).__init__()
247
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
248
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
249
+
250
+ def forward(self, hidden_states):
251
+ hidden_states = self.dense(hidden_states)
252
+ hidden_states = self.dropout(hidden_states)
253
+ return hidden_states
254
+
255
+ class AlbertIntermediate(nn.Module):
256
+ def __init__(self, config):
257
+ super(AlbertIntermediate, self).__init__()
258
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
259
+ self.output = AlbertOutput(config)
260
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
261
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
262
+ else:
263
+ self.intermediate_act_fn = config.hidden_act
264
+
265
+ def forward(self, hidden_states):
266
+ intermediate_output = self.dense(hidden_states)
267
+ intermediate_output = self.intermediate_act_fn(intermediate_output)
268
+ output = self.output(intermediate_output)
269
+ return output
270
+
271
+ class AlbertFFN(nn.Module):
272
+ def __init__(self, config):
273
+ super(AlbertFFN, self).__init__()
274
+ self.intermediate = AlbertIntermediate(config)
275
+
276
+ def forward(self, attention_output):
277
+ output = self.intermediate(attention_output)
278
+ return output
279
+
280
+ class AlbertLayer(nn.Module):
281
+ def __init__(self, config):
282
+ super(AlbertLayer, self).__init__()
283
+ self.attention = AlbertAttention(config)
284
+ self.ffn = AlbertFFN(config)
285
+ self.LayerNorm = AlbertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
286
+ self.LayerNorm_1 = AlbertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
287
+
288
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
289
+ attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
290
+ attention_output = self.LayerNorm(attention_outputs[0] + hidden_states)
291
+ ffn_output = self.ffn(attention_output)
292
+ ffn_output = self.LayerNorm_1(ffn_output+attention_output)
293
+ outputs = (ffn_output,) + attention_outputs[1:] # add attentions if we output them
294
+ return outputs
295
+
296
+ class AlbertGroup(nn.Module):
297
+ def __init__(self, config):
298
+ super(AlbertGroup, self).__init__()
299
+ self.inner_group_num = config.inner_group_num
300
+ self.inner_group = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
301
+
302
+ def forward(self, hidden_states, attention_mask, head_mask):
303
+ layer_attentions = ()
304
+ layer_hidden_states = ()
305
+ for inner_group_idx in range(self.inner_group_num):
306
+ layer_module = self.inner_group[inner_group_idx]
307
+ layer_outputs = layer_module(hidden_states, attention_mask, head_mask)
308
+ hidden_states = layer_outputs[0]
309
+ layer_attentions = layer_attentions + (layer_outputs[1],)
310
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
311
+ return (layer_hidden_states, layer_attentions)
312
+
313
+ class AlbertTransformer(nn.Module):
314
+ def __init__(self, config):
315
+ super(AlbertTransformer, self).__init__()
316
+ self.output_attentions = config.output_attentions
317
+ self.output_hidden_states = config.output_hidden_states
318
+ self.num_hidden_layers = config.num_hidden_layers
319
+ self.num_hidden_groups = config.num_hidden_groups
320
+ self.group = nn.ModuleList([AlbertGroup(config) for _ in range(config.num_hidden_groups)])
321
+
322
+ def forward(self, hidden_states, attention_mask, head_mask):
323
+ all_hidden_states = ()
324
+ all_attentions = ()
325
+ for layer_idx in range(self.num_hidden_layers):
326
+ if self.output_hidden_states and layer_idx == 0:
327
+ all_hidden_states = all_hidden_states + (hidden_states,)
328
+ group_idx = int(layer_idx / self.num_hidden_layers * self.num_hidden_groups)
329
+ layer_module = self.group[group_idx]
330
+ layer_outputs = layer_module(hidden_states, attention_mask, head_mask[layer_idx])
331
+ hidden_states = layer_outputs[0][-1]
332
+ if self.output_attentions:
333
+ all_attentions = all_attentions + layer_outputs[1]
334
+ if self.output_hidden_states:
335
+ all_hidden_states = all_hidden_states + layer_outputs[0]
336
+ outputs = (hidden_states,)
337
+ if self.output_hidden_states:
338
+ outputs = outputs + (all_hidden_states,)
339
+ if self.output_attentions:
340
+ outputs = outputs + (all_attentions,)
341
+ return outputs # last-layer hidden state, (all hidden states), (all attentions)
342
+
343
+ class AlbertEncoder(nn.Module):
344
+ def __init__(self, config):
345
+ super(AlbertEncoder, self).__init__()
346
+ self.hidden_size = config.hidden_size
347
+ self.embedding_size = config.embedding_size
348
+ self.embedding_hidden_mapping_in = nn.Linear(self.embedding_size, self.hidden_size)
349
+ self.transformer = AlbertTransformer(config)
350
+
351
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
352
+ if self.embedding_size != self.hidden_size:
353
+ prev_output = self.embedding_hidden_mapping_in(hidden_states)
354
+ else:
355
+ prev_output = hidden_states
356
+ outputs = self.transformer(prev_output, attention_mask, head_mask)
357
+ return outputs # last-layer hidden state, (all hidden states), (all attentions)
358
+
359
+ class AlbertPooler(nn.Module):
360
+ def __init__(self, config):
361
+ super(AlbertPooler, self).__init__()
362
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
363
+ self.activation = nn.Tanh()
364
+
365
+ def forward(self, hidden_states):
366
+ # We "pool" the model by simply taking the hidden state corresponding
367
+ # to the first token.
368
+ first_token_tensor = hidden_states[:, 0]
369
+ pooled_output = self.dense(first_token_tensor)
370
+ pooled_output = self.activation(pooled_output)
371
+ return pooled_output
372
+
373
+ class AlbertPredictionHeadTransform(nn.Module):
374
+ def __init__(self, config):
375
+ super(AlbertPredictionHeadTransform, self).__init__()
376
+ self.dense = nn.Linear(config.hidden_size, config.embedding_size)
377
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
378
+ self.transform_act_fn = ACT2FN[config.hidden_act]
379
+ else:
380
+ self.transform_act_fn = config.hidden_act
381
+ self.LayerNorm = AlbertLayerNorm(config.embedding_size, eps=config.layer_norm_eps)
382
+
383
+ def forward(self, hidden_states):
384
+ hidden_states = self.dense(hidden_states)
385
+ hidden_states = self.transform_act_fn(hidden_states)
386
+ hidden_states = self.LayerNorm(hidden_states)
387
+ return hidden_states
388
+
389
+ class AlbertLMPredictionHead(nn.Module):
390
+ def __init__(self, config):
391
+ super(AlbertLMPredictionHead, self).__init__()
392
+ self.transform = AlbertPredictionHeadTransform(config)
393
+ # The output weights are the same as the input embeddings, but there is
394
+ # an output-only bias for each token.
395
+ self.decoder = nn.Linear(config.embedding_size,config.vocab_size,bias=False)
396
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
397
+
398
+ def forward(self, hidden_states):
399
+ hidden_states = self.transform(hidden_states)
400
+ hidden_states = self.decoder(hidden_states) + self.bias
401
+ return hidden_states
402
+
403
+ class AlbertOnlyMLMHead(nn.Module):
404
+ def __init__(self, config):
405
+ super(AlbertOnlyMLMHead, self).__init__()
406
+ self.predictions = AlbertLMPredictionHead(config)
407
+
408
+ def forward(self, sequence_output):
409
+ prediction_scores = self.predictions(sequence_output)
410
+ return prediction_scores
411
+
412
+ class AlbertOnlyNSPHead(nn.Module):
413
+ def __init__(self, config):
414
+ super(AlbertOnlyNSPHead, self).__init__()
415
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
416
+
417
+ def forward(self, pooled_output):
418
+ seq_relationship_score = self.seq_relationship(pooled_output)
419
+ return seq_relationship_score
420
+
421
+ class AlbertPreTrainingHeads(nn.Module):
422
+ def __init__(self, config):
423
+ super(AlbertPreTrainingHeads, self).__init__()
424
+ self.predictions = AlbertLMPredictionHead(config)
425
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
426
+
427
+ def forward(self, sequence_output, pooled_output):
428
+ prediction_scores = self.predictions(sequence_output)
429
+ seq_relationship_score = self.seq_relationship(pooled_output)
430
+ return prediction_scores, seq_relationship_score
431
+
432
+ class AlbertPreTrainedModel(PreTrainedModel):
433
+ """ An abstract class to handle weights initialization and
434
+ a simple interface for downloading and loading pretrained models.
435
+ """
436
+ config_class = AlbertConfig
437
+ pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
438
+ load_tf_weights = load_tf_weights_in_albert
439
+ base_model_prefix = "bert"
440
+
441
+ def _init_weights(self, module):
442
+ """ Initialize the weights """
443
+ if isinstance(module, (nn.Linear, nn.Embedding)):
444
+ # Slightly different from the TF version which uses truncated_normal for initialization
445
+ # cf https://github.com/pytorch/pytorch/pull/5617
446
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
447
+ elif isinstance(module, AlbertLayerNorm):
448
+ module.bias.data.zero_()
449
+ module.weight.data.fill_(1.0)
450
+ if isinstance(module, nn.Linear) and module.bias is not None:
451
+ module.bias.data.zero_()
452
+
453
+
454
+ ALBERT_START_DOCSTRING = r""" The ALBERT model was proposed in
455
+ `ALBERT: A Lite BERT for Self-supervised Learning of Language Representations`_
456
+ by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.
457
+ This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
458
+ refer to the PyTorch documentation for all matter related to general usage and behavior.
459
+ .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
460
+ https://arxiv.org/abs/1909.11942
461
+ .. _`torch.nn.Module`:
462
+ https://pytorch.org/docs/stable/nn.html#module
463
+ Parameters:
464
+ config (:class:`~transformers.AlbertConfig`): Model configuration class with all the parameters of the model.
465
+ Initializing with a config file does not load the weights associated with the model, only the configuration.
466
+ Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
467
+ """
468
+
469
+ ALBERT_INPUTS_DOCSTRING = r"""
470
+ Inputs:
471
+ **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
472
+ Indices of input sequence tokens in the vocabulary.
473
+ To match pre-training, ALBERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:
474
+ (a) For sequence pairs:
475
+ ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
476
+ ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
477
+ (b) For single sequences:
478
+ ``tokens: [CLS] the dog is hairy . [SEP]``
479
+ ``token_type_ids: 0 0 0 0 0 0 0``
480
+ ALBERT is a model with absolute position embeddings, so it's usually advised to pad the inputs on
481
+ the right rather than the left.
482
+ Indices can be obtained using :class:`transformers.BertTokenizer`.
483
+ See :func:`transformers.PreTrainedTokenizer.encode` and
484
+ :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
485
+ **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
486
+ Mask to avoid performing attention on padding token indices.
487
+ Mask values selected in ``[0, 1]``:
488
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
489
+ **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
490
+ Segment token indices to indicate first and second portions of the inputs.
491
+ Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
492
+ corresponds to a `sentence B` token
493
+ (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
494
+ **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
495
+ Indices of positions of each input sequence tokens in the position embeddings.
496
+ Selected in the range ``[0, config.max_position_embeddings - 1]``.
497
+ **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
498
+ Mask to nullify selected heads of the self-attention modules.
499
+ Mask values selected in ``[0, 1]``:
500
+ ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
501
+ """
502
+
503
+ @add_start_docstrings("The bare Albert Model transformer outputting raw hidden-states without any specific head on top.",
504
+ ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
505
+ class AlbertModel(AlbertPreTrainedModel):
506
+ r"""
507
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
508
+ **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
509
+ Sequence of hidden-states at the output of the last layer of the model.
510
+ **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
511
+ Last layer hidden-state of the first token of the sequence (classification token)
512
+ further processed by a Linear layer and a Tanh activation function. The Linear
513
+ layer weights are trained from the next sentence prediction (classification)
514
+ objective during Bert pretraining. This output is usually *not* a good summary
515
+ of the semantic content of the input, you're often better with averaging or pooling
516
+ the sequence of hidden-states for the whole input sequence.
517
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
518
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
519
+ of shape ``(batch_size, sequence_length, hidden_size)``:
520
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
521
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
522
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
523
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
524
+ Examples::
525
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
526
+ model = BertModel.from_pretrained('bert-base-uncased')
527
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
528
+ outputs = model(input_ids)
529
+ last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
530
+ """
531
+
532
+ def __init__(self, config):
533
+ super(AlbertModel, self).__init__(config)
534
+
535
+ self.embeddings = AlbertEmbeddings(config)
536
+ self.encoder = AlbertEncoder(config)
537
+ self.pooler = AlbertPooler(config)
538
+
539
+ self.init_weights()
540
+
541
+ def _resize_token_embeddings(self, new_num_tokens):
542
+ old_embeddings = self.embeddings.word_embeddings
543
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
544
+ self.embeddings.word_embeddings = new_embeddings
545
+ return self.embeddings.word_embeddings
546
+
547
+ def _prune_heads(self, heads_to_prune):
548
+ """ Prunes heads of the model.
549
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
550
+ See base class PreTrainedModel
551
+ """
552
+ for layer, heads in heads_to_prune.items():
553
+ self.encoder.layer[layer].attention.prune_heads(heads)
554
+
555
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
556
+ if attention_mask is None:
557
+ attention_mask = torch.ones_like(input_ids)
558
+ if token_type_ids is None:
559
+ token_type_ids = torch.zeros_like(input_ids)
560
+
561
+ # We create a 3D attention mask from a 2D tensor mask.
562
+ # Sizes are [batch_size, 1, 1, to_seq_length]
563
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
564
+ # this attention mask is simpler than the triangular masking of causal attention
565
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
566
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
567
+
568
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
569
+ # masked positions, this operation will create a tensor which is 0.0 for
570
+ # positions we want to attend and -10000.0 for masked positions.
571
+ # Since we are adding it to the raw scores before the softmax, this is
572
+ # effectively the same as removing these entirely.
573
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
574
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
575
+
576
+ # Prepare head mask if needed
577
+ # 1.0 in head_mask indicates we keep the head
578
+ # attention_probs has shape bsz x n_heads x N x N
579
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
580
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
581
+ if head_mask is not None:
582
+ if head_mask.dim() == 1:
583
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
584
+ head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
585
+ elif head_mask.dim() == 2:
586
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(
587
+ -1) # We can specify head_mask for each layer
588
+ head_mask = head_mask.to(
589
+ dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
590
+ else:
591
+ head_mask = [None] * self.config.num_hidden_layers
592
+
593
+ embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
594
+ encoder_outputs = self.encoder(embedding_output,
595
+ extended_attention_mask,
596
+ head_mask=head_mask)
597
+ sequence_output = encoder_outputs[0]
598
+ pooled_output = self.pooler(sequence_output)
599
+
600
+ outputs = (sequence_output, pooled_output,) + encoder_outputs[
601
+ 1:] # add hidden_states and attentions if they are here
602
+ return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
603
+
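Editor's note: the padding-mask arithmetic inside AlbertModel.forward above is easy to check in isolation. The following is a minimal standalone sketch (toy tensors chosen for illustration, not part of the uploaded file) of how 1/0 mask values become the 0.0/-10000.0 additive bias that is added to the attention scores before the softmax:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])              # (batch_size, seq_len); 0 marks padding
extended = attention_mask[:, None, None, :].to(torch.float)   # (batch_size, 1, 1, seq_len) for broadcasting
extended = (1.0 - extended) * -10000.0                        # attend -> 0.0, padding -> -10000.0
print(extended.squeeze())                                     # tensor([0., 0., 0., -10000., -10000.])

After the softmax, the -10000.0 positions receive essentially zero attention weight, which is why no explicit masking of the probabilities is needed.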
604
+ @add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
605
+ a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
606
+ ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
607
+ class AlbertForPreTraining(AlbertPreTrainedModel):
608
+ r"""
609
+ **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
610
+ Labels for computing the masked language modeling loss.
611
+ Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
612
+ Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
613
+ in ``[0, ..., config.vocab_size]``
614
+ **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
615
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
616
+ Indices should be in ``[0, 1]``.
617
+ ``0`` indicates sequence B is a continuation of sequence A,
618
+ ``1`` indicates sequence B is a random sequence.
619
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
620
+ **loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
621
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
622
+ **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
623
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
624
+ **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
625
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
626
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
627
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
628
+ of shape ``(batch_size, sequence_length, hidden_size)``:
629
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
630
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
631
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
632
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
633
+ Examples::
634
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
635
+ model = BertForPreTraining.from_pretrained('bert-base-uncased')
636
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
637
+ outputs = model(input_ids)
638
+ prediction_scores, seq_relationship_scores = outputs[:2]
639
+ """
640
+
641
+ def __init__(self, config):
642
+ super(AlbertForPreTraining, self).__init__(config)
643
+ self.bert = AlbertModel(config)
644
+ self.cls = AlbertPreTrainingHeads(config)
645
+
646
+ self.init_weights()
647
+ self.tie_weights()
648
+
649
+ def tie_weights(self):
650
+ """ Make sure we are sharing the input and output embeddings.
651
+ Export to TorchScript can't handle parameter sharing so we are cloning them instead.
652
+ """
653
+ self._tie_or_clone_weights(self.cls.predictions.decoder,
654
+ self.bert.embeddings.word_embeddings)
655
+
656
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
657
+ masked_lm_labels=None, next_sentence_label=None):
658
+ outputs = self.bert(input_ids,
659
+ attention_mask=attention_mask,
660
+ token_type_ids=token_type_ids,
661
+ position_ids=position_ids,
662
+ head_mask=head_mask)
663
+
664
+ sequence_output, pooled_output = outputs[:2]
665
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
666
+
667
+ outputs = (prediction_scores, seq_relationship_score,) + outputs[
668
+ 2:] # add hidden states and attention if they are here
669
+
670
+ if masked_lm_labels is not None and next_sentence_label is not None:
671
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
672
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
673
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
674
+ total_loss = masked_lm_loss + next_sentence_loss
675
+ outputs = (total_loss,) + outputs
676
+ return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
677
+
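Editor's note: as a quick, self-contained illustration of how AlbertForPreTraining above sums its two objectives (random tensors stand in for real model outputs; toy sizes are assumptions, and label value -1 is ignored in the MLM term exactly as in the code):

import torch
from torch.nn import CrossEntropyLoss

vocab_size, batch, seq_len = 10, 2, 4
prediction_scores = torch.randn(batch, seq_len, vocab_size)    # stand-in for the MLM head output
seq_relationship_score = torch.randn(batch, 2)                 # stand-in for the NSP head output
masked_lm_labels = torch.tensor([[1, -1, 3, -1], [-1, 2, -1, 4]])
next_sentence_label = torch.tensor([0, 1])

loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size), masked_lm_labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss               # what the model returns as `loss`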
678
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
679
+ ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
680
+ class AlbertForMaskedLM(AlbertPreTrainedModel):
681
+ r"""
682
+ **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
683
+ Labels for computing the masked language modeling loss.
684
+ Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
685
+ Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
686
+ in ``[0, ..., config.vocab_size]``
687
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
688
+ **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
689
+ Masked language modeling loss.
690
+ **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
691
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
692
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
693
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
694
+ of shape ``(batch_size, sequence_length, hidden_size)``:
695
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
696
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
697
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
698
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
699
+ Examples::
700
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
701
+ model = BertForMaskedLM.from_pretrained('bert-base-uncased')
702
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
703
+ outputs = model(input_ids, masked_lm_labels=input_ids)
704
+ loss, prediction_scores = outputs[:2]
705
+ """
706
+
707
+ def __init__(self, config):
708
+ super(AlbertForMaskedLM, self).__init__(config)
709
+
710
+ self.bert = AlbertModel(config)
711
+ self.cls = AlbertOnlyMLMHead(config)
712
+
713
+ self.init_weights()
714
+ self.tie_weights()
715
+
716
+ def tie_weights(self):
717
+ """ Make sure we are sharing the input and output embeddings.
718
+ Export to TorchScript can't handle parameter sharing so we are cloning them instead.
719
+ """
720
+ self._tie_or_clone_weights(self.cls.predictions.decoder,
721
+ self.bert.embeddings.word_embeddings)
722
+
723
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
724
+ masked_lm_labels=None):
725
+ outputs = self.bert(input_ids,
726
+ attention_mask=attention_mask,
727
+ token_type_ids=token_type_ids,
728
+ position_ids=position_ids,
729
+ head_mask=head_mask)
730
+
731
+ sequence_output = outputs[0]
732
+ prediction_scores = self.cls(sequence_output)
733
+
734
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
735
+ if masked_lm_labels is not None:
736
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
737
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
738
+ outputs = (masked_lm_loss,) + outputs
739
+
740
+ return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
741
+
742
+
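Editor's note: to make the output contract of AlbertForMaskedLM concrete, here is a toy sketch (random scores used in place of a real forward pass) of turning prediction_scores into token ids; with a trained checkpoint, the ids at [MASK] positions are the model's fill-in guesses:

import torch

batch, seq_len, vocab_size = 1, 6, 20
prediction_scores = torch.randn(batch, seq_len, vocab_size)    # same shape as model(input_ids)[0]
predicted_ids = prediction_scores.argmax(dim=-1)               # (batch, seq_len) most likely token per position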
743
+ @add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
744
+ ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
745
+ class AlbertForNextSentencePrediction(AlbertPreTrainedModel):
746
+ r"""
747
+ **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
748
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
749
+ Indices should be in ``[0, 1]``.
750
+ ``0`` indicates sequence B is a continuation of sequence A,
751
+ ``1`` indicates sequence B is a random sequence.
752
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
753
+ **loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
754
+ Next sequence prediction (classification) loss.
755
+ **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, 2)``
756
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
757
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
758
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
759
+ of shape ``(batch_size, sequence_length, hidden_size)``:
760
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
761
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
762
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
763
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
764
+ Examples::
765
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
766
+ model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
767
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
768
+ outputs = model(input_ids)
769
+ seq_relationship_scores = outputs[0]
770
+ """
771
+
772
+ def __init__(self, config):
773
+ super(AlbertForNextSentencePrediction, self).__init__(config)
774
+
775
+ self.bert = AlbertModel(config)
776
+ self.cls = AlbertOnlyNSPHead(config)
777
+
778
+ self.init_weights()
779
+
780
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
781
+ next_sentence_label=None):
782
+ outputs = self.bert(input_ids,
783
+ attention_mask=attention_mask,
784
+ token_type_ids=token_type_ids,
785
+ position_ids=position_ids,
786
+ head_mask=head_mask)
787
+
788
+ pooled_output = outputs[1]
789
+
790
+ seq_relationship_score = self.cls(pooled_output)
791
+
792
+ outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
793
+ if next_sentence_label is not None:
794
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
795
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
796
+ outputs = (next_sentence_loss,) + outputs
797
+
798
+ return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
799
+
800
+
801
+ @add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
802
+ the pooled output) e.g. for GLUE tasks. """,
803
+ ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
804
+ class AlbertForSequenceClassification(AlbertPreTrainedModel):
805
+ r"""
806
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
807
+ Labels for computing the sequence classification/regression loss.
808
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
809
+ If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
810
+ If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
811
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
812
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
813
+ Classification (or regression if config.num_labels==1) loss.
814
+ **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
815
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
816
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
817
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
818
+ of shape ``(batch_size, sequence_length, hidden_size)``:
819
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
820
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
821
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
822
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
823
+ Examples::
824
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
825
+ model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
826
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
827
+ labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
828
+ outputs = model(input_ids, labels=labels)
829
+ loss, logits = outputs[:2]
830
+ """
831
+
832
+ def __init__(self, config):
833
+ super(AlbertForSequenceClassification, self).__init__(config)
834
+ self.num_labels = config.num_labels
835
+
836
+ self.bert = AlbertModel(config)
837
+ self.dropout = nn.Dropout(0.1 if config.hidden_dropout_prob == 0 else config.hidden_dropout_prob)
838
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
839
+
840
+ self.init_weights()
841
+
842
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None,
843
+ position_ids=None, head_mask=None, labels=None):
844
+
845
+ outputs = self.bert(input_ids,
846
+ attention_mask=attention_mask,
847
+ token_type_ids=token_type_ids,
848
+ position_ids=position_ids,
849
+ head_mask=head_mask)
850
+
851
+ pooled_output = outputs[1]
852
+
853
+ pooled_output = self.dropout(pooled_output)
854
+ logits = self.classifier(pooled_output)
855
+
856
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
857
+
858
+ if labels is not None:
859
+ if self.num_labels == 1:
860
+ # We are doing regression
861
+ loss_fct = MSELoss()
862
+ loss = loss_fct(logits.view(-1), labels.view(-1))
863
+ else:
864
+ loss_fct = CrossEntropyLoss()
865
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
866
+ outputs = (loss,) + outputs
867
+
868
+ return outputs # (loss), logits, (hidden_states), (attentions)
869
+
870
+
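Editor's note: a standalone sketch of the label-handling branch in AlbertForSequenceClassification above (toy tensors; num_labels == 1 is treated as regression with MSE, anything larger as classification with cross-entropy):

import torch
from torch.nn import CrossEntropyLoss, MSELoss

logits_cls = torch.randn(4, 3)                                 # (batch, num_labels) with num_labels = 3
labels_cls = torch.tensor([0, 2, 1, 2])
cls_loss = CrossEntropyLoss()(logits_cls.view(-1, 3), labels_cls.view(-1))

logits_reg = torch.randn(4, 1)                                 # num_labels = 1 -> regression
labels_reg = torch.randn(4)
reg_loss = MSELoss()(logits_reg.view(-1), labels_reg.view(-1))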
871
+ @add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
872
+ the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
873
+ ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
874
+ class AlbertForMultipleChoice(AlbertPreTrainedModel):
875
+ r"""
876
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
877
+ Labels for computing the multiple choice classification loss.
878
+ Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
879
+ of the input tensors. (see `input_ids` above)
880
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
881
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
882
+ Classification loss.
883
+ **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
884
+ of the input tensors. (see `input_ids` above).
885
+ Classification scores (before SoftMax).
886
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
887
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
888
+ of shape ``(batch_size, sequence_length, hidden_size)``:
889
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
890
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
891
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
892
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
893
+ Examples::
894
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
895
+ model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
896
+ choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
897
+ input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
898
+ labels = torch.tensor(1).unsqueeze(0) # Batch size 1
899
+ outputs = model(input_ids, labels=labels)
900
+ loss, classification_scores = outputs[:2]
901
+ """
902
+
903
+ def __init__(self, config):
904
+ super(AlbertForMultipleChoice, self).__init__(config)
905
+
906
+ self.bert = AlbertModel(config)
907
+ self.dropout = nn.Dropout(0.1 if config.hidden_dropout_prob == 0 else config.hidden_dropout_prob)
908
+ self.classifier = nn.Linear(config.hidden_size, 1)
909
+
910
+ self.init_weights()
911
+
912
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None,
913
+ position_ids=None, head_mask=None, labels=None):
914
+ num_choices = input_ids.shape[1]
915
+
916
+ input_ids = input_ids.view(-1, input_ids.size(-1))
917
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
918
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
919
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
920
+ outputs = self.bert(input_ids,
921
+ attention_mask=attention_mask,
922
+ token_type_ids=token_type_ids,
923
+ position_ids=position_ids,
924
+ head_mask=head_mask)
925
+ pooled_output = outputs[1]
926
+ pooled_output = self.dropout(pooled_output)
927
+ logits = self.classifier(pooled_output)
928
+ reshaped_logits = logits.view(-1, num_choices)
929
+ outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
930
+ if labels is not None:
931
+ loss_fct = CrossEntropyLoss()
932
+ loss = loss_fct(reshaped_logits, labels)
933
+ outputs = (loss,) + outputs
934
+
935
+ return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
936
+
937
+
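Editor's note: a toy sketch of the reshaping done in AlbertForMultipleChoice.forward above. Inputs of shape (batch, num_choices, seq_len) are flattened, each choice is scored with a single-unit classifier, and the scores are folded back to (batch, num_choices); the sizes below are illustrative only:

import torch

batch, num_choices, seq_len, hidden = 2, 4, 8, 16
input_ids = torch.randint(0, 100, (batch, num_choices, seq_len))
flat_ids = input_ids.view(-1, input_ids.size(-1))              # (batch * num_choices, seq_len)
pooled = torch.randn(flat_ids.size(0), hidden)                 # stand-in for the pooler output
logits = torch.nn.Linear(hidden, 1)(pooled)                    # one score per (example, choice)
reshaped_logits = logits.view(-1, num_choices)                 # (batch, num_choices), fed to cross-entropy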
938
+ @add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
939
+ the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
940
+ ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
941
+
942
+ class AlbertForTokenClassification(AlbertPreTrainedModel):
943
+ r"""
944
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
945
+ Labels for computing the token classification loss.
946
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
947
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
948
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
949
+ Classification loss.
950
+ **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
951
+ Classification scores (before SoftMax).
952
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
953
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
954
+ of shape ``(batch_size, sequence_length, hidden_size)``:
955
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
956
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
957
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
958
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
959
+ Examples::
960
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
961
+ model = BertForTokenClassification.from_pretrained('bert-base-uncased')
962
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
963
+ labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
964
+ outputs = model(input_ids, labels=labels)
965
+ loss, scores = outputs[:2]
966
+ """
967
+
968
+ def __init__(self, config):
969
+ super(AlbertForTokenClassification, self).__init__(config)
970
+ self.num_labels = config.num_labels
971
+
972
+ self.bert = AlbertModel(config)
973
+ self.dropout = nn.Dropout(0.1 if config.hidden_dropout_prob == 0 else config.hidden_dropout_prob)
974
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
975
+
976
+ self.init_weights()
977
+
978
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None,
979
+ position_ids=None, head_mask=None, labels=None):
980
+
981
+ outputs = self.bert(input_ids,
982
+ attention_mask=attention_mask,
983
+ token_type_ids=token_type_ids,
984
+ position_ids=position_ids,
985
+ head_mask=head_mask)
986
+
987
+ sequence_output = outputs[0]
988
+
989
+ sequence_output = self.dropout(sequence_output)
990
+ logits = self.classifier(sequence_output)
991
+
992
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
993
+ if labels is not None:
994
+ loss_fct = CrossEntropyLoss()
995
+ # Only keep active parts of the loss
996
+ if attention_mask is not None:
997
+ active_loss = attention_mask.view(-1) == 1
998
+ active_logits = logits.view(-1, self.num_labels)[active_loss]
999
+ active_labels = labels.view(-1)[active_loss]
1000
+ loss = loss_fct(active_logits, active_labels)
1001
+ else:
1002
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1003
+ outputs = (loss,) + outputs
1004
+
1005
+ return outputs # (loss), scores, (hidden_states), (attentions)
1006
+
1007
+
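Editor's note: a standalone sketch of the "active loss" trick in AlbertForTokenClassification above, where padded positions (attention_mask == 0) are dropped before the cross-entropy; tensor sizes are illustrative assumptions:

import torch
from torch.nn import CrossEntropyLoss

num_labels = 5
logits = torch.randn(2, 4, num_labels)                         # (batch, seq_len, num_labels)
labels = torch.randint(0, num_labels, (2, 4))
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

active = attention_mask.view(-1) == 1                          # keep only real (non-padded) tokens
loss = CrossEntropyLoss()(logits.view(-1, num_labels)[active], labels.view(-1)[active])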
1008
+ @add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
1009
+ the hidden-states output to compute `span start logits` and `span end logits`). """,
1010
+ ALBERT_START_DOCSTRING, ALBERT_INPUTS_DOCSTRING)
1011
+ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
1012
+ r"""
1013
+ **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
1014
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1015
+ Positions are clamped to the length of the sequence (`sequence_length`).
1016
+ Position outside of the sequence are not taken into account for computing the loss.
1017
+ **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
1018
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1019
+ Positions are clamped to the length of the sequence (`sequence_length`).
1020
+ Position outside of the sequence are not taken into account for computing the loss.
1021
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
1022
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
1023
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
1024
+ **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
1025
+ Span-start scores (before SoftMax).
1026
+ **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
1027
+ Span-end scores (before SoftMax).
1028
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
1029
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
1030
+ of shape ``(batch_size, sequence_length, hidden_size)``:
1031
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1032
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
1033
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
1034
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
1035
+ Examples::
1036
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1037
+ model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
1038
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
1039
+ start_positions = torch.tensor([1])
1040
+ end_positions = torch.tensor([3])
1041
+ outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
1042
+ loss, start_scores, end_scores = outputs[:3]
1043
+ """
1044
+
1045
+ def __init__(self, config):
1046
+ super(AlbertForQuestionAnswering, self).__init__(config)
1047
+ self.num_labels = config.num_labels
1048
+
1049
+ self.bert = AlbertModel(config)
1050
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1051
+
1052
+ self.init_weights()
1053
+
1054
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
1055
+ start_positions=None, end_positions=None):
1056
+
1057
+ outputs = self.bert(input_ids,
1058
+ attention_mask=attention_mask,
1059
+ token_type_ids=token_type_ids,
1060
+ position_ids=position_ids,
1061
+ head_mask=head_mask)
1062
+
1063
+ sequence_output = outputs[0]
1064
+
1065
+ logits = self.qa_outputs(sequence_output)
1066
+ start_logits, end_logits = logits.split(1, dim=-1)
1067
+ start_logits = start_logits.squeeze(-1)
1068
+ end_logits = end_logits.squeeze(-1)
1069
+
1070
+ outputs = (start_logits, end_logits,) + outputs[2:]
1071
+ if start_positions is not None and end_positions is not None:
1072
+ # If we are on multi-GPU, split add a dimension
1073
+ if len(start_positions.size()) > 1:
1074
+ start_positions = start_positions.squeeze(-1)
1075
+ if len(end_positions.size()) > 1:
1076
+ end_positions = end_positions.squeeze(-1)
1077
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1078
+ ignored_index = start_logits.size(1)
1079
+ start_positions.clamp_(0, ignored_index)
1080
+ end_positions.clamp_(0, ignored_index)
1081
+
1082
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1083
+ start_loss = loss_fct(start_logits, start_positions)
1084
+ end_loss = loss_fct(end_logits, end_positions)
1085
+ total_loss = (start_loss + end_loss) / 2
1086
+ outputs = (total_loss,) + outputs
1087
+
1088
+ return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
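Editor's note: a toy sketch of turning AlbertForQuestionAnswering's start/end logits into an answer span. This is a plain greedy argmax for illustration; real decoders usually also enforce start <= end and a maximum answer length:

import torch

seq_len = 10
start_logits = torch.randn(1, seq_len)                         # stand-in for the model's start logits
end_logits = torch.randn(1, seq_len)                           # stand-in for the model's end logits
start_index = start_logits.argmax(dim=-1).item()
end_index = end_logits.argmax(dim=-1).item()
answer_span = (start_index, end_index)                         # token indices into the input sequence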
model/modeling_albert_bright.py ADDED
@@ -0,0 +1,1002 @@
 
1
+ """PyTorch brightmart version ALBERT model. """
2
+ from __future__ import absolute_import, division, print_function, unicode_literals
3
+
4
+ import logging
5
+ import os
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import CrossEntropyLoss, MSELoss
10
+
11
+ from .modeling_utils import PreTrainedModel, prune_linear_layer
12
+ from .configuration_albert import AlbertConfig
13
+ from .file_utils import add_start_docstrings
14
+ from .modeling_bert import (ACT2FN,
15
+ BertSelfAttention,
16
+ BertIntermediate,
17
+ BertPooler,
18
+ BertPredictionHeadTransform)
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
23
+ 'albert-base': "",
24
+ 'albert-large': "",
25
+ 'albert-xlarge': "",
26
+ 'albert-xxlarge': "",
27
+ }
28
+ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
29
+ """ Load tf checkpoints in a pytorch model.
30
+ """
31
+ try:
32
+ import re
33
+ import numpy as np
34
+ import tensorflow as tf
35
+ except ImportError:
36
+ logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
37
+ "https://www.tensorflow.org/install/ for installation instructions.")
38
+ raise
39
+ tf_path = os.path.abspath(tf_checkpoint_path)
40
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
41
+ # Load weights from TF model
42
+ init_vars = tf.train.list_variables(tf_path)
43
+ names = []
44
+ arrays = []
45
+ for name, shape in init_vars:
46
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
47
+ array = tf.train.load_variable(tf_path, name)
48
+ names.append(name)
49
+ arrays.append(array)
50
+ for name, array in zip(names, arrays):
51
+ name = name.split('/')
52
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
53
+ # which are not required for using pretrained model
54
+ if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
55
+ logger.info("Skipping {}".format("/".join(name)))
56
+ continue
57
+ pointer = model
58
+ for m_name in name:
59
+ if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
60
+ l = re.split(r'_(\d+)', m_name)
61
+ else:
62
+ l = [m_name]
63
+ if l[0] == 'kernel' or l[0] == 'gamma':
64
+ pointer = getattr(pointer, 'weight')
65
+ elif l[0] == 'output_bias' or l[0] == 'beta':
66
+ pointer = getattr(pointer, 'bias')
67
+ elif l[0] == 'output_weights':
68
+ pointer = getattr(pointer, 'weight')
69
+ elif l[0] == 'squad':
70
+ pointer = getattr(pointer, 'classifier')
71
+ else:
72
+ try:
73
+ pointer = getattr(pointer, l[0])
74
+ except AttributeError:
75
+ logger.info("Skipping {}".format("/".join(name)))
76
+ continue
77
+ if len(l) >= 2:
78
+ num = int(l[1])
79
+ pointer = pointer[num]
80
+ if m_name[-11:] == '_embeddings':
81
+ pointer = getattr(pointer, 'weight')
82
+ elif m_name[-13:] == '_embeddings_2':
83
+ pointer = getattr(pointer, 'weight')
84
+ array = np.transpose(array)
85
+ elif m_name == 'kernel':
86
+ array = np.transpose(array)
87
+ try:
88
+ assert pointer.shape == array.shape
89
+ except AssertionError as e:
90
+ e.args += (pointer.shape, array.shape)
91
+ raise
92
+ logger.info("Initialize PyTorch weight {}".format(name))
93
+ pointer.data = torch.from_numpy(array)
94
+ return model
95
+
96
+ AlbertLayerNorm = torch.nn.LayerNorm
97
+ class AlbertEmbeddings(nn.Module):
98
+ """Construct the embeddings from word, position and token_type embeddings.
99
+ """
100
+ def __init__(self, config):
101
+ super(AlbertEmbeddings, self).__init__()
102
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=0)
103
+ # project layer
104
+ self.word_embeddings_2 = nn.Linear(config.embedding_size, config.hidden_size, bias=False)
105
+
106
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
107
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
108
+
109
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
110
+ # any TensorFlow checkpoint file
111
+ self.LayerNorm = AlbertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
112
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
113
+
114
+ def forward(self, input_ids, token_type_ids=None, position_ids=None):
115
+ seq_length = input_ids.size(1)
116
+ if position_ids is None:
117
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
118
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
119
+ if token_type_ids is None:
120
+ token_type_ids = torch.zeros_like(input_ids)
121
+
122
+ words_embeddings = self.word_embeddings(input_ids)
123
+ # project transform
124
+ words_embeddings = self.word_embeddings_2(words_embeddings)
125
+ position_embeddings = self.position_embeddings(position_ids)
126
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
127
+
128
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
129
+ embeddings = self.LayerNorm(embeddings)
130
+ embeddings = self.dropout(embeddings)
131
+ return embeddings
132
+
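Editor's note: the word_embeddings / word_embeddings_2 pair above implements ALBERT's factorized embedding parameterization (vocab -> embedding_size -> hidden_size). A back-of-the-envelope check with hypothetical sizes (not this repository's actual config) shows why the factorization saves parameters:

# V = vocab size, E = embedding size, H = hidden size (illustrative values only)
V, E, H = 30000, 128, 768
untied = V * H                     # 23,040,000 parameters for a direct V x H embedding table
factorized = V * E + E * H         # 3,938,304 parameters for a V x E table plus an E x H projection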
133
+ class AlbertSelfOutput(nn.Module):
134
+ def __init__(self, config):
135
+ super(AlbertSelfOutput, self).__init__()
136
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
137
+ self.LayerNorm = AlbertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
138
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
139
+
140
+ def forward(self, hidden_states, input_tensor):
141
+ hidden_states = self.dense(hidden_states)
142
+ hidden_states = self.dropout(hidden_states)
143
+ # postln
144
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
145
+ return hidden_states
146
+
147
+
148
+ class AlbertAttention(nn.Module):
149
+ def __init__(self, config):
150
+ super(AlbertAttention, self).__init__()
151
+ self.self = BertSelfAttention(config)
152
+ self.output = AlbertSelfOutput(config)
153
+ self.pruned_heads = set()
154
+
155
+ def prune_heads(self, heads):
156
+ if len(heads) == 0:
157
+ return
158
+ mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
159
+ heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
160
+ for head in heads:
161
+ # Compute how many pruned heads are before the head and move the index accordingly
162
+ head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
163
+ mask[head] = 0
164
+ mask = mask.view(-1).contiguous().eq(1)
165
+ index = torch.arange(len(mask))[mask].long()
166
+
167
+ # Prune linear layers
168
+ self.self.query = prune_linear_layer(self.self.query, index)
169
+ self.self.key = prune_linear_layer(self.self.key, index)
170
+ self.self.value = prune_linear_layer(self.self.value, index)
171
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
172
+
173
+ # Update hyper params and store pruned heads
174
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
175
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
176
+ self.pruned_heads = self.pruned_heads.union(heads)
177
+
178
+ def forward(self, input_tensor, attention_mask=None, head_mask=None):
179
+ # postln
180
+ self_outputs = self.self(input_tensor, attention_mask, head_mask)
181
+ attention_output = self.output(self_outputs[0], input_tensor)
182
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
183
+ return outputs
184
+
185
+ class AlbertOutput(nn.Module):
186
+ def __init__(self, config):
187
+ super(AlbertOutput, self).__init__()
188
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
189
+ self.LayerNorm = AlbertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
190
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
191
+
192
+ def forward(self, hidden_states, input_tensor):
193
+ hidden_states = self.dense(hidden_states)
194
+ hidden_states = self.dropout(hidden_states)
195
+ # postln
196
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
197
+ return hidden_states
198
+
199
+ class BertLayer(nn.Module):
200
+ def __init__(self, config):
201
+ super(BertLayer, self).__init__()
202
+ self.attention = AlbertAttention(config)
203
+ self.intermediate = BertIntermediate(config)
204
+ self.output = AlbertOutput(config)
205
+
206
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
207
+ attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
208
+ attention_output = attention_outputs[0]
209
+ # postln
210
+ attention_output_pre = attention_output
211
+ intermediate_output = self.intermediate(attention_output_pre)
212
+ layer_output = self.output(intermediate_output, attention_output)
213
+ outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
214
+ return outputs
215
+
216
+ class AlbertEncoder(nn.Module):
217
+ def __init__(self, config):
218
+ super(AlbertEncoder, self).__init__()
219
+ self.output_attentions = config.output_attentions
220
+ self.output_hidden_states = config.output_hidden_states
221
+ self.num_hidden_layers = config.num_hidden_layers
222
+ self.layer_shared = BertLayer(config)
223
+
224
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
225
+ all_hidden_states = ()
226
+ all_attentions = ()
227
+ for i in range(self.num_hidden_layers):
228
+ layer_module = self.layer_shared
229
+ if self.output_hidden_states:
230
+ all_hidden_states = all_hidden_states + (hidden_states,)
231
+ layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
232
+ hidden_states = layer_outputs[0]
233
+
234
+ if self.output_attentions:
235
+ all_attentions = all_attentions + (layer_outputs[1],)
236
+ # Add last layer
237
+ if self.output_hidden_states:
238
+ all_hidden_states = all_hidden_states + (hidden_states,)
239
+ outputs = (hidden_states,)
240
+ if self.output_hidden_states:
241
+ outputs = outputs + (all_hidden_states,)
242
+ if self.output_attentions:
243
+ outputs = outputs + (all_attentions,)
244
+ return outputs # last-layer hidden state, (all hidden states), (all attentions)
245
+
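Editor's note: AlbertEncoder above applies one shared BertLayer instance num_hidden_layers times (cross-layer parameter sharing), so adding depth adds no new parameters. A minimal sketch of that pattern with a stand-in module (toy sizes, illustration only):

import torch
from torch import nn

shared = nn.Linear(8, 8)                        # stand-in for the single shared layer
x = torch.randn(2, 8)
for _ in range(12):                             # "num_hidden_layers" passes through the same module
    x = torch.tanh(shared(x))                   # identical weights reused every iteration
num_params = sum(p.numel() for p in shared.parameters())   # 72, regardless of how many passes are made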
246
+ class AlbertLMPredictionHead(nn.Module):
247
+ def __init__(self, config):
248
+ super(AlbertLMPredictionHead, self).__init__()
249
+ self.transform = BertPredictionHeadTransform(config)
250
+ # The output weights are the same as the input embeddings, but there is
251
+ # an output-only bias for each token.
252
+ self.project_layer = nn.Linear(config.hidden_size, config.embedding_size, bias=False)
253
+ self.decoder = nn.Linear(config.embedding_size,
254
+ config.vocab_size,
255
+ bias=False)
256
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
257
+
258
+ def forward(self, hidden_states):
259
+ hidden_states = self.transform(hidden_states)
260
+ hidden_states = self.project_layer(hidden_states)
261
+ hidden_states = self.decoder(hidden_states) + self.bias
262
+ return hidden_states
263
+
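Editor's note: a shape walk-through (toy sizes, assumptions only) of AlbertLMPredictionHead above. Hidden states are projected back down to the embedding size before the vocabulary-sized decoder, mirroring the factorized input embedding:

import torch
from torch import nn

hidden_size, embedding_size, vocab_size = 32, 16, 100
h = torch.randn(2, 5, hidden_size)
project_layer = nn.Linear(hidden_size, embedding_size, bias=False)   # hidden -> embedding
decoder = nn.Linear(embedding_size, vocab_size, bias=False)          # embedding -> vocab
bias = torch.zeros(vocab_size)
scores = decoder(project_layer(h)) + bias        # (2, 5, vocab_size), per-token vocabulary scores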
264
+ class AlbertOnlyMLMHead(nn.Module):
265
+ def __init__(self, config):
266
+ super(AlbertOnlyMLMHead, self).__init__()
267
+ self.predictions = AlbertLMPredictionHead(config)
268
+
269
+ def forward(self, sequence_output):
270
+ prediction_scores = self.predictions(sequence_output)
271
+ return prediction_scores
272
+
273
+ class AlbertOnlyNSPHead(nn.Module):
274
+ def __init__(self, config):
275
+ super(AlbertOnlyNSPHead, self).__init__()
276
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
277
+
278
+ def forward(self, pooled_output):
279
+ seq_relationship_score = self.seq_relationship(pooled_output)
280
+ return seq_relationship_score
281
+
282
+ class AlbertPreTrainingHeads(nn.Module):
283
+ def __init__(self, config):
284
+ super(AlbertPreTrainingHeads, self).__init__()
285
+ self.predictions = AlbertLMPredictionHead(config)
286
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
287
+
288
+ def forward(self, sequence_output, pooled_output):
289
+ prediction_scores = self.predictions(sequence_output)
290
+ seq_relationship_score = self.seq_relationship(pooled_output)
291
+ return prediction_scores, seq_relationship_score
292
+
293
+ class AlbertPreTrainedModel(PreTrainedModel):
294
+ """ An abstract class to handle weights initialization and
295
+ a simple interface for downloading and loading pretrained models.
296
+ """
297
+ config_class = AlbertConfig
298
+ pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
299
+ load_tf_weights = load_tf_weights_in_albert
300
+ base_model_prefix = "bert"
301
+
302
+ def _init_weights(self, module):
303
+ """ Initialize the weights """
304
+ if isinstance(module, (nn.Linear, nn.Embedding)):
305
+ # Slightly different from the TF version which uses truncated_normal for initialization
306
+ # cf https://github.com/pytorch/pytorch/pull/5617
307
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
308
+ elif isinstance(module, AlbertLayerNorm):
309
+ module.bias.data.zero_()
310
+ module.weight.data.fill_(1.0)
311
+ if isinstance(module, nn.Linear) and module.bias is not None:
312
+ module.bias.data.zero_()
313
+
314
+ BERT_START_DOCSTRING = r""" The BERT model was proposed in
315
+ `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
316
+ by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
317
+ pre-trained using a combination of masked language modeling objective and next sentence prediction
318
+ on a large corpus comprising the Toronto Book Corpus and Wikipedia.
319
+
320
+ This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
321
+ refer to the PyTorch documentation for all matter related to general usage and behavior.
322
+
323
+ .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
324
+ https://arxiv.org/abs/1810.04805
325
+
326
+ .. _`torch.nn.Module`:
327
+ https://pytorch.org/docs/stable/nn.html#module
328
+
329
+ Parameters:
330
+ config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
331
+ Initializing with a config file does not load the weights associated with the model, only the configuration.
332
+ Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
333
+ """
334
+
335
+ BERT_INPUTS_DOCSTRING = r"""
336
+ Inputs:
337
+ **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
338
+ Indices of input sequence tokens in the vocabulary.
339
+ To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:
340
+
341
+ (a) For sequence pairs:
342
+
343
+ ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
344
+
345
+ ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
346
+
347
+ (b) For single sequences:
348
+
349
+ ``tokens: [CLS] the dog is hairy . [SEP]``
350
+
351
+ ``token_type_ids: 0 0 0 0 0 0 0``
352
+
353
+ Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on
354
+ the right rather than the left.
355
+
356
+ Indices can be obtained using :class:`transformers.BertTokenizer`.
357
+ See :func:`transformers.PreTrainedTokenizer.encode` and
358
+ :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
359
+ **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
360
+ Mask to avoid performing attention on padding token indices.
361
+ Mask values selected in ``[0, 1]``:
362
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
363
+ **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
364
+ Segment token indices to indicate first and second portions of the inputs.
365
+ Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
366
+ corresponds to a `sentence B` token
367
+ (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
368
+ **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
369
+ Indices of positions of each input sequence tokens in the position embeddings.
370
+ Selected in the range ``[0, config.max_position_embeddings - 1]``.
371
+ **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
372
+ Mask to nullify selected heads of the self-attention modules.
373
+ Mask values selected in ``[0, 1]``:
374
+ ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
375
+ """
376
+
377
+
378
+ @add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
379
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
380
+ class AlbertModel(AlbertPreTrainedModel):
381
+ r"""
382
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
383
+ **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
384
+ Sequence of hidden-states at the output of the last layer of the model.
385
+ **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
386
+ Last layer hidden-state of the first token of the sequence (classification token)
387
+ further processed by a Linear layer and a Tanh activation function. The Linear
388
+ layer weights are trained from the next sentence prediction (classification)
389
+ objective during Bert pretraining. This output is usually *not* a good summary
390
+ of the semantic content of the input, you're often better with averaging or pooling
391
+ the sequence of hidden-states for the whole input sequence.
392
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
393
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
394
+ of shape ``(batch_size, sequence_length, hidden_size)``:
395
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
396
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
397
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
398
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
399
+
400
+ Examples::
401
+
402
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
403
+ model = BertModel.from_pretrained('bert-base-uncased')
404
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
405
+ outputs = model(input_ids)
406
+ last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
407
+
408
+ """
409
+
410
+ def __init__(self, config):
411
+ super(AlbertModel, self).__init__(config)
412
+
413
+ self.embeddings = AlbertEmbeddings(config)
414
+ self.encoder = AlbertEncoder(config)
415
+ self.pooler = BertPooler(config)
416
+
417
+ self.init_weights()
418
+
419
+ def _resize_token_embeddings(self, new_num_tokens):
420
+ old_embeddings = self.embeddings.word_embeddings
421
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
422
+ self.embeddings.word_embeddings = new_embeddings
423
+ return self.embeddings.word_embeddings
424
+
425
+ def _prune_heads(self, heads_to_prune):
426
+ """ Prunes heads of the model.
427
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
428
+ See base class PreTrainedModel
429
+ """
430
+ for layer, heads in heads_to_prune.items():
431
+ self.encoder.layer[layer].attention.prune_heads(heads)
432
+
433
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
434
+ if attention_mask is None:
435
+ attention_mask = torch.ones_like(input_ids)
436
+ if token_type_ids is None:
437
+ token_type_ids = torch.zeros_like(input_ids)
438
+
439
+ # We create a 3D attention mask from a 2D tensor mask.
440
+ # Sizes are [batch_size, 1, 1, to_seq_length]
441
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
442
+ # this attention mask is more simple than the triangular masking of causal attention
443
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
444
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
445
+
446
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
447
+ # masked positions, this operation will create a tensor which is 0.0 for
448
+ # positions we want to attend and -10000.0 for masked positions.
449
+ # Since we are adding it to the raw scores before the softmax, this is
450
+ # effectively the same as removing these entirely.
451
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
452
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
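+ # Illustrative note (a sketch, not part of the original code): for attention_mask = [[1, 1, 0]],
+ # the extended mask becomes [[[[0.0, 0.0, -10000.0]]]] with shape [1, 1, 1, 3], so the masked
+ # position receives a large negative score before the softmax and effectively drops out.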
453
+
454
+ # Prepare head mask if needed
455
+ # 1.0 in head_mask indicate we keep the head
456
+ # attention_probs has shape bsz x n_heads x N x N
457
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
458
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
459
+ if head_mask is not None:
460
+ if head_mask.dim() == 1:
461
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
462
+ head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
463
+ elif head_mask.dim() == 2:
464
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(
465
+ -1) # We can specify head_mask for each layer
466
+ head_mask = head_mask.to(
467
+ dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
468
+ else:
469
+ head_mask = [None] * self.config.num_hidden_layers
470
+
471
+ embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
472
+ encoder_outputs = self.encoder(embedding_output,
473
+ extended_attention_mask,
474
+ head_mask=head_mask)
475
+ sequence_output = encoder_outputs[0]
476
+ pooled_output = self.pooler(sequence_output)
477
+
478
+ outputs = (sequence_output, pooled_output,) + encoder_outputs[
479
+ 1:] # add hidden_states and attentions if they are here
480
+ return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
481
+
482
+
483
+ @add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
484
+ a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
485
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
486
+ class AlbertForPreTraining(AlbertPreTrainedModel):
487
+ r"""
488
+ **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
489
+ Labels for computing the masked language modeling loss.
490
+ Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
491
+ Tokens with indices set to ``-1`` are ignored (masked); the loss is only computed for the tokens with labels
492
+ in ``[0, ..., config.vocab_size]``
493
+ **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
494
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
495
+ Indices should be in ``[0, 1]``.
496
+ ``0`` indicates sequence B is a continuation of sequence A,
497
+ ``1`` indicates sequence B is a random sequence.
498
+
499
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
500
+ **loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
501
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
502
+ **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
503
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
504
+ **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
505
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
506
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
507
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
508
+ of shape ``(batch_size, sequence_length, hidden_size)``:
509
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
510
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
511
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
512
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
513
+
514
+ Examples::
515
+
516
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
517
+ model = BertForPreTraining.from_pretrained('bert-base-uncased')
518
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
519
+ outputs = model(input_ids)
520
+ prediction_scores, seq_relationship_scores = outputs[:2]
521
+
522
+ """
523
+
524
+ def __init__(self, config):
525
+ super(AlbertForPreTraining, self).__init__(config)
526
+
527
+ self.bert = AlbertModel(config)
528
+ self.cls = AlbertPreTrainingHeads(config)
529
+
530
+ self.init_weights()
531
+ self.tie_weights()
532
+
533
+ def tie_weights(self):
534
+ """ Make sure we are sharing the input and output embeddings.
535
+ Export to TorchScript can't handle parameter sharing so we are cloning them instead.
536
+ """
537
+ self._tie_or_clone_weights(self.cls.predictions.decoder,
538
+ self.bert.embeddings.word_embeddings)
539
+
540
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
541
+ masked_lm_labels=None, next_sentence_label=None):
542
+ outputs = self.bert(input_ids,
543
+ attention_mask=attention_mask,
544
+ token_type_ids=token_type_ids,
545
+ position_ids=position_ids,
546
+ head_mask=head_mask)
547
+
548
+ sequence_output, pooled_output = outputs[:2]
549
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
550
+
551
+ outputs = (prediction_scores, seq_relationship_score,) + outputs[
552
+ 2:] # add hidden states and attention if they are here
553
+
554
+ if masked_lm_labels is not None and next_sentence_label is not None:
555
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
556
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
557
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
558
+ total_loss = masked_lm_loss + next_sentence_loss
559
+ outputs = (total_loss,) + outputs
560
+
561
+ return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
562
+
563
+
564
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
565
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
566
+ class AlbertForMaskedLM(AlbertPreTrainedModel):
567
+ r"""
568
+ **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
569
+ Labels for computing the masked language modeling loss.
570
+ Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
571
+ Tokens with indices set to ``-1`` are ignored (masked); the loss is only computed for the tokens with labels
572
+ in ``[0, ..., config.vocab_size]``
573
+
574
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
575
+ **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
576
+ Masked language modeling loss.
577
+ **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
578
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
579
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
580
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
581
+ of shape ``(batch_size, sequence_length, hidden_size)``:
582
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
583
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
584
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
585
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
586
+
587
+ Examples::
588
+
589
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
590
+ model = BertForMaskedLM.from_pretrained('bert-base-uncased')
591
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
592
+ outputs = model(input_ids, masked_lm_labels=input_ids)
593
+ loss, prediction_scores = outputs[:2]
594
+
595
+ """
596
+
597
+ def __init__(self, config):
598
+ super(AlbertForMaskedLM, self).__init__(config)
599
+
600
+ self.bert = AlbertModel(config)
601
+ self.cls = AlbertOnlyMLMHead(config)
602
+
603
+ self.init_weights()
604
+ self.tie_weights()
605
+
606
+ def tie_weights(self):
607
+ """ Make sure we are sharing the input and output embeddings.
608
+ Export to TorchScript can't handle parameter sharing so we are cloning them instead.
609
+ """
610
+ self._tie_or_clone_weights(self.cls.predictions.decoder,
611
+ self.bert.embeddings.word_embeddings)
612
+
613
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
614
+ masked_lm_labels=None):
615
+ outputs = self.bert(input_ids,
616
+ attention_mask=attention_mask,
617
+ token_type_ids=token_type_ids,
618
+ position_ids=position_ids,
619
+ head_mask=head_mask)
620
+
621
+ sequence_output = outputs[0]
622
+ prediction_scores = self.cls(sequence_output)
623
+
624
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
625
+ if masked_lm_labels is not None:
626
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
627
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
628
+ outputs = (masked_lm_loss,) + outputs
629
+
630
+ return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
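+ # Usage sketch (illustrative, assuming the ignore_index=-1 convention above): to train on masked
+ # positions only, copy input_ids into masked_lm_labels, set every non-masked position to -1, and
+ # pass both tensors to forward(); outputs[0] is then the masked language modeling loss.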
631
+
632
+
633
+ @add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
634
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
635
+ class AlbertForNextSentencePrediction(AlbertPreTrainedModel):
636
+ r"""
637
+ **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
638
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
639
+ Indices should be in ``[0, 1]``.
640
+ ``0`` indicates sequence B is a continuation of sequence A,
641
+ ``1`` indicates sequence B is a random sequence.
642
+
643
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
644
+ **loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
645
+ Next sequence prediction (classification) loss.
646
+ **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
647
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
648
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
649
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
650
+ of shape ``(batch_size, sequence_length, hidden_size)``:
651
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
652
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
653
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
654
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
655
+
656
+ Examples::
657
+
658
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
659
+ model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
660
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
661
+ outputs = model(input_ids)
662
+ seq_relationship_scores = outputs[0]
663
+
664
+ """
665
+
666
+ def __init__(self, config):
667
+ super(AlbertForNextSentencePrediction, self).__init__(config)
668
+
669
+ self.bert = AlbertModel(config)
670
+ self.cls = AlbertOnlyNSPHead(config)
671
+
672
+ self.init_weights()
673
+
674
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
675
+ next_sentence_label=None):
676
+ outputs = self.bert(input_ids,
677
+ attention_mask=attention_mask,
678
+ token_type_ids=token_type_ids,
679
+ position_ids=position_ids,
680
+ head_mask=head_mask)
681
+
682
+ pooled_output = outputs[1]
683
+
684
+ seq_relationship_score = self.cls(pooled_output)
685
+
686
+ outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
687
+ if next_sentence_label is not None:
688
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
689
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
690
+ outputs = (next_sentence_loss,) + outputs
691
+
692
+ return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
693
+
694
+
695
+ @add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
696
+ the pooled output) e.g. for GLUE tasks. """,
697
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
698
+ class AlbertForSequenceClassification(AlbertPreTrainedModel):
699
+ r"""
700
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
701
+ Labels for computing the sequence classification/regression loss.
702
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
703
+ If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
704
+ If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
705
+
706
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
707
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
708
+ Classification (or regression if config.num_labels==1) loss.
709
+ **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
710
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
711
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
712
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
713
+ of shape ``(batch_size, sequence_length, hidden_size)``:
714
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
715
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
716
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
717
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
718
+
719
+ Examples::
720
+
721
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
722
+ model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
723
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
724
+ labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
725
+ outputs = model(input_ids, labels=labels)
726
+ loss, logits = outputs[:2]
727
+
728
+ """
729
+
730
+ def __init__(self, config):
731
+ super(AlbertForSequenceClassification, self).__init__(config)
732
+ self.num_labels = config.num_labels
733
+
734
+ self.bert = AlbertModel(config)
735
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
736
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
737
+
738
+ self.init_weights()
739
+
740
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None,
741
+ position_ids=None, head_mask=None, labels=None):
742
+
743
+ outputs = self.bert(input_ids,
744
+ attention_mask=attention_mask,
745
+ token_type_ids=token_type_ids,
746
+ position_ids=position_ids,
747
+ head_mask=head_mask)
748
+
749
+ pooled_output = outputs[1]
750
+
751
+ pooled_output = self.dropout(pooled_output)
752
+ logits = self.classifier(pooled_output)
753
+
754
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
755
+
756
+ if labels is not None:
757
+ if self.num_labels == 1:
758
+ # We are doing regression
759
+ loss_fct = MSELoss()
760
+ loss = loss_fct(logits.view(-1), labels.view(-1))
761
+ else:
762
+ loss_fct = CrossEntropyLoss()
763
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
764
+ outputs = (loss,) + outputs
765
+
766
+ return outputs # (loss), logits, (hidden_states), (attentions)
767
+
768
+
769
+ @add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
770
+ the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
771
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
772
+ class AlbertForMultipleChoice(AlbertPreTrainedModel):
773
+ r"""
774
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
775
+ Labels for computing the multiple choice classification loss.
776
+ Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
777
+ of the input tensors. (see `input_ids` above)
778
+
779
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
780
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
781
+ Classification loss.
782
+ **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
783
+ of the input tensors. (see `input_ids` above).
784
+ Classification scores (before SoftMax).
785
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
786
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
787
+ of shape ``(batch_size, sequence_length, hidden_size)``:
788
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
789
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
790
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
791
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
792
+
793
+ Examples::
794
+
795
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
796
+ model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
797
+ choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
798
+ input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
799
+ labels = torch.tensor(1).unsqueeze(0) # Batch size 1
800
+ outputs = model(input_ids, labels=labels)
801
+ loss, classification_scores = outputs[:2]
802
+
803
+ """
804
+
805
+ def __init__(self, config):
806
+ super(AlbertForMultipleChoice, self).__init__(config)
807
+
808
+ self.bert = AlbertModel(config)
809
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
810
+ self.classifier = nn.Linear(config.hidden_size, 1)
811
+
812
+ self.init_weights()
813
+
814
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None,
815
+ position_ids=None, head_mask=None, labels=None):
816
+ num_choices = input_ids.shape[1]
817
+
818
+ input_ids = input_ids.view(-1, input_ids.size(-1))
819
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
820
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
821
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
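+ # Shape note (illustrative): multiple-choice inputs arrive as (batch_size, num_choices, seq_len)
+ # and are flattened here to (batch_size * num_choices, seq_len) so each choice is scored
+ # independently; the per-choice logits are reshaped back to (batch_size, num_choices) below.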
822
+
823
+ outputs = self.bert(input_ids,
824
+ attention_mask=attention_mask,
825
+ token_type_ids=token_type_ids,
826
+ position_ids=position_ids,
827
+ head_mask=head_mask)
828
+
829
+ pooled_output = outputs[1]
830
+
831
+ pooled_output = self.dropout(pooled_output)
832
+ logits = self.classifier(pooled_output)
833
+ reshaped_logits = logits.view(-1, num_choices)
834
+
835
+ outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
836
+
837
+ if labels is not None:
838
+ loss_fct = CrossEntropyLoss()
839
+ loss = loss_fct(reshaped_logits, labels)
840
+ outputs = (loss,) + outputs
841
+
842
+ return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
843
+
844
+
845
+ @add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
846
+ the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
847
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
848
+ class AlbertForTokenClassification(AlbertPreTrainedModel):
849
+ r"""
850
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
851
+ Labels for computing the token classification loss.
852
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
853
+
854
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
855
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
856
+ Classification loss.
857
+ **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
858
+ Classification scores (before SoftMax).
859
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
860
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
861
+ of shape ``(batch_size, sequence_length, hidden_size)``:
862
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
863
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
864
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
865
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
866
+
867
+ Examples::
868
+
869
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
870
+ model = BertForTokenClassification.from_pretrained('bert-base-uncased')
871
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
872
+ labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
873
+ outputs = model(input_ids, labels=labels)
874
+ loss, scores = outputs[:2]
875
+
876
+ """
877
+
878
+ def __init__(self, config):
879
+ super(AlbertForTokenClassification, self).__init__(config)
880
+ self.num_labels = config.num_labels
881
+
882
+ self.bert = AlbertModel(config)
883
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
884
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
885
+
886
+ self.init_weights()
887
+
888
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None,
889
+ position_ids=None, head_mask=None, labels=None):
890
+
891
+ outputs = self.bert(input_ids,
892
+ attention_mask=attention_mask,
893
+ token_type_ids=token_type_ids,
894
+ position_ids=position_ids,
895
+ head_mask=head_mask)
896
+
897
+ sequence_output = outputs[0]
898
+
899
+ sequence_output = self.dropout(sequence_output)
900
+ logits = self.classifier(sequence_output)
901
+
902
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
903
+ if labels is not None:
904
+ loss_fct = CrossEntropyLoss()
905
+ # Only keep active parts of the loss
906
+ if attention_mask is not None:
907
+ active_loss = attention_mask.view(-1) == 1
908
+ active_logits = logits.view(-1, self.num_labels)[active_loss]
909
+ active_labels = labels.view(-1)[active_loss]
910
+ loss = loss_fct(active_logits, active_labels)
911
+ else:
912
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
913
+ outputs = (loss,) + outputs
914
+
915
+ return outputs # (loss), scores, (hidden_states), (attentions)
916
+
917
+
918
+ @add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
919
+ the hidden-states output to compute `span start logits` and `span end logits`). """,
920
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
921
+ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
922
+ r"""
923
+ **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
924
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
925
+ Positions are clamped to the length of the sequence (`sequence_length`).
926
+ Positions outside of the sequence are not taken into account for computing the loss.
927
+ **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
928
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
929
+ Positions are clamped to the length of the sequence (`sequence_length`).
930
+ Positions outside of the sequence are not taken into account for computing the loss.
931
+
932
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
933
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
934
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
935
+ **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
936
+ Span-start scores (before SoftMax).
937
+ **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
938
+ Span-end scores (before SoftMax).
939
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
940
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
941
+ of shape ``(batch_size, sequence_length, hidden_size)``:
942
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
943
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
944
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
945
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
946
+
947
+ Examples::
948
+
949
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
950
+ model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
951
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
952
+ start_positions = torch.tensor([1])
953
+ end_positions = torch.tensor([3])
954
+ outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
955
+ loss, start_scores, end_scores = outputs[:3]
956
+
957
+ """
958
+
959
+ def __init__(self, config):
960
+ super(AlbertForQuestionAnswering, self).__init__(config)
961
+ self.num_labels = config.num_labels
962
+
963
+ self.bert = AlbertModel(config)
964
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
965
+
966
+ self.init_weights()
967
+
968
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
969
+ start_positions=None, end_positions=None):
970
+
971
+ outputs = self.bert(input_ids,
972
+ attention_mask=attention_mask,
973
+ token_type_ids=token_type_ids,
974
+ position_ids=position_ids,
975
+ head_mask=head_mask)
976
+
977
+ sequence_output = outputs[0]
978
+
979
+ logits = self.qa_outputs(sequence_output)
980
+ start_logits, end_logits = logits.split(1, dim=-1)
981
+ start_logits = start_logits.squeeze(-1)
982
+ end_logits = end_logits.squeeze(-1)
983
+
984
+ outputs = (start_logits, end_logits,) + outputs[2:]
985
+ if start_positions is not None and end_positions is not None:
986
+ # If we are on multi-GPU, splitting can add an extra dimension
987
+ if len(start_positions.size()) > 1:
988
+ start_positions = start_positions.squeeze(-1)
989
+ if len(end_positions.size()) > 1:
990
+ end_positions = end_positions.squeeze(-1)
991
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
992
+ ignored_index = start_logits.size(1)
993
+ start_positions.clamp_(0, ignored_index)
994
+ end_positions.clamp_(0, ignored_index)
995
+
996
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
997
+ start_loss = loss_fct(start_logits, start_positions)
998
+ end_loss = loss_fct(end_logits, end_positions)
999
+ total_loss = (start_loss + end_loss) / 2
1000
+ outputs = (total_loss,) + outputs
1001
+
1002
+ return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
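A minimal usage sketch for the Albert classes above (illustrative only; the import path and checkpoint directory name are assumptions based on this repository's layout, and tokenization is skipped in favour of hand-built token ids):

    import torch
    from model.modeling_albert import AlbertForMaskedLM

    # from_pretrained() reads config.json and pytorch_model.bin from a local checkpoint directory
    model = AlbertForMaskedLM.from_pretrained('tibetan-albert-syllable-base')
    model.eval()

    input_ids = torch.tensor([[2, 17, 35, 4, 29, 3]])  # toy batch: 1 sequence of 6 token ids
    with torch.no_grad():
        prediction_scores = model(input_ids)[0]  # shape (1, 6, config.vocab_size)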
model/modeling_bert.py ADDED
@@ -0,0 +1,1149 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model. """
17
+
18
+ from __future__ import absolute_import, division, print_function, unicode_literals
19
+
20
+ import json
21
+ import logging
22
+ import math
23
+ import os
24
+ import sys
25
+ from io import open
26
+
27
+ import torch
28
+ from torch import nn
29
+ from torch.nn import CrossEntropyLoss, MSELoss
30
+
31
+ from .modeling_utils import PreTrainedModel, prune_linear_layer
32
+ from .configuration_bert import BertConfig
33
+ from .file_utils import add_start_docstrings
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
38
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
39
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
40
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
41
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
42
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
43
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
44
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
45
+ 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
46
+ 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
47
+ 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
48
+ 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
49
+ 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
50
+ 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
51
+ 'bert-base-german-dbmdz-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
52
+ 'bert-base-german-dbmdz-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
53
+ }
54
+
55
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
56
+ """ Load tf checkpoints in a pytorch model.
57
+ """
58
+ try:
59
+ import re
60
+ import numpy as np
61
+ import tensorflow as tf
62
+ except ImportError:
63
+ logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
64
+ "https://www.tensorflow.org/install/ for installation instructions.")
65
+ raise
66
+ tf_path = os.path.abspath(tf_checkpoint_path)
67
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
68
+ # Load weights from TF model
69
+ init_vars = tf.train.list_variables(tf_path)
70
+ names = []
71
+ arrays = []
72
+ for name, shape in init_vars:
73
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
74
+ array = tf.train.load_variable(tf_path, name)
75
+ names.append(name)
76
+ arrays.append(array)
77
+
78
+ for name, array in zip(names, arrays):
79
+ name = name.split('/')
80
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
81
+ # which are not required for using pretrained model
82
+ if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
83
+ logger.info("Skipping {}".format("/".join(name)))
84
+ continue
85
+ pointer = model
86
+ for m_name in name:
87
+ if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
88
+ l = re.split(r'_(\d+)', m_name)
89
+ else:
90
+ l = [m_name]
91
+ if l[0] == 'kernel' or l[0] == 'gamma':
92
+ pointer = getattr(pointer, 'weight')
93
+ elif l[0] == 'output_bias' or l[0] == 'beta':
94
+ pointer = getattr(pointer, 'bias')
95
+ elif l[0] == 'output_weights':
96
+ pointer = getattr(pointer, 'weight')
97
+ elif l[0] == 'squad':
98
+ pointer = getattr(pointer, 'classifier')
99
+ else:
100
+ try:
101
+ pointer = getattr(pointer, l[0])
102
+ except AttributeError:
103
+ logger.info("Skipping {}".format("/".join(name)))
104
+ continue
105
+ if len(l) >= 2:
106
+ num = int(l[1])
107
+ pointer = pointer[num]
108
+ if m_name[-11:] == '_embeddings':
109
+ pointer = getattr(pointer, 'weight')
110
+ elif m_name == 'kernel':
111
+ array = np.transpose(array)
112
+ try:
113
+ assert pointer.shape == array.shape
114
+ except AssertionError as e:
115
+ e.args += (pointer.shape, array.shape)
116
+ raise
117
+ logger.info("Initialize PyTorch weight {}".format(name))
118
+ pointer.data = torch.from_numpy(array)
119
+ return model
120
+
121
+
122
+ def gelu(x):
123
+ """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
124
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
125
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
126
+ Also see https://arxiv.org/abs/1606.08415
127
+ """
128
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
129
+
130
+ def gelu_new(x):
131
+ """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
132
+ Also see https://arxiv.org/abs/1606.08415
133
+ """
134
+ return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
135
+
136
+ def swish(x):
137
+ return x * torch.sigmoid(x)
138
+
139
+
140
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new}
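+ # ACT2FN lets the config select the activation by name, e.g. ACT2FN["gelu"] is the erf-based
+ # gelu defined above, while "gelu_new" is the tanh approximation used in OpenAI GPT.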
141
+
142
+
143
+ BertLayerNorm = torch.nn.LayerNorm
144
+
145
+ class BertEmbeddings(nn.Module):
146
+ """Construct the embeddings from word, position and token_type embeddings.
147
+ """
148
+ def __init__(self, config):
149
+ super(BertEmbeddings, self).__init__()
150
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
151
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
152
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
153
+
154
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
155
+ # any TensorFlow checkpoint file
156
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
157
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
158
+
159
+ def forward(self, input_ids, token_type_ids=None, position_ids=None):
160
+ seq_length = input_ids.size(1)
161
+ if position_ids is None:
162
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
163
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
164
+ if token_type_ids is None:
165
+ token_type_ids = torch.zeros_like(input_ids)
166
+
167
+ words_embeddings = self.word_embeddings(input_ids)
168
+ position_embeddings = self.position_embeddings(position_ids)
169
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
170
+
171
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
172
+ embeddings = self.LayerNorm(embeddings)
173
+ embeddings = self.dropout(embeddings)
174
+ return embeddings
175
+
176
+
177
+ class BertSelfAttention(nn.Module):
178
+ def __init__(self, config):
179
+ super(BertSelfAttention, self).__init__()
180
+ if config.hidden_size % config.num_attention_heads != 0:
181
+ raise ValueError(
182
+ "The hidden size (%d) is not a multiple of the number of attention "
183
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
184
+ self.output_attentions = config.output_attentions
185
+
186
+ self.num_attention_heads = config.num_attention_heads
187
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
188
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
189
+
190
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
191
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
192
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
193
+
194
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
195
+
196
+ def transpose_for_scores(self, x):
197
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
198
+ x = x.view(*new_x_shape)
199
+ return x.permute(0, 2, 1, 3)
200
+
201
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
202
+ mixed_query_layer = self.query(hidden_states)
203
+ mixed_key_layer = self.key(hidden_states)
204
+ mixed_value_layer = self.value(hidden_states)
205
+
206
+ query_layer = self.transpose_for_scores(mixed_query_layer)
207
+ key_layer = self.transpose_for_scores(mixed_key_layer)
208
+ value_layer = self.transpose_for_scores(mixed_value_layer)
209
+
210
+ # Take the dot product between "query" and "key" to get the raw attention scores.
211
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
212
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
213
+ if attention_mask is not None:
214
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
215
+ attention_scores = attention_scores + attention_mask
216
+
217
+ # Normalize the attention scores to probabilities.
218
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
219
+
220
+ # This is actually dropping out entire tokens to attend to, which might
221
+ # seem a bit unusual, but is taken from the original Transformer paper.
222
+ attention_probs = self.dropout(attention_probs)
223
+
224
+ # Mask heads if we want to
225
+ if head_mask is not None:
226
+ attention_probs = attention_probs * head_mask
227
+
228
+ context_layer = torch.matmul(attention_probs, value_layer)
229
+
230
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
231
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
232
+ context_layer = context_layer.view(*new_context_layer_shape)
233
+
234
+ outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
235
+ return outputs
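+ # Shape walkthrough (illustrative): hidden_states is (batch, seq_len, hidden_size);
+ # transpose_for_scores reshapes each projection to (batch, num_heads, seq_len, head_size),
+ # the query/key matmul yields (batch, num_heads, seq_len, seq_len) attention scores, and
+ # context_layer is permuted and reshaped back to (batch, seq_len, hidden_size).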
236
+
237
+
238
+ class BertSelfOutput(nn.Module):
239
+ def __init__(self, config):
240
+ super(BertSelfOutput, self).__init__()
241
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
242
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
243
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
244
+
245
+ def forward(self, hidden_states, input_tensor):
246
+ hidden_states = self.dense(hidden_states)
247
+ hidden_states = self.dropout(hidden_states)
248
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
249
+ return hidden_states
250
+
251
+
252
+ class BertAttention(nn.Module):
253
+ def __init__(self, config):
254
+ super(BertAttention, self).__init__()
255
+ self.self = BertSelfAttention(config)
256
+ self.output = BertSelfOutput(config)
257
+ self.pruned_heads = set()
258
+
259
+ def prune_heads(self, heads):
260
+ if len(heads) == 0:
261
+ return
262
+ mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
263
+ heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
264
+ for head in heads:
265
+ # Compute how many pruned heads are before the head and move the index accordingly
266
+ head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
267
+ mask[head] = 0
268
+ mask = mask.view(-1).contiguous().eq(1)
269
+ index = torch.arange(len(mask))[mask].long()
270
+
271
+ # Prune linear layers
272
+ self.self.query = prune_linear_layer(self.self.query, index)
273
+ self.self.key = prune_linear_layer(self.self.key, index)
274
+ self.self.value = prune_linear_layer(self.self.value, index)
275
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
276
+
277
+ # Update hyper params and store pruned heads
278
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
279
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
280
+ self.pruned_heads = self.pruned_heads.union(heads)
281
+
282
+ def forward(self, input_tensor, attention_mask=None, head_mask=None):
283
+ self_outputs = self.self(input_tensor, attention_mask, head_mask)
284
+ attention_output = self.output(self_outputs[0], input_tensor)
285
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
286
+ return outputs
287
+
288
+
289
+ class BertIntermediate(nn.Module):
290
+ def __init__(self, config):
291
+ super(BertIntermediate, self).__init__()
292
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
293
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
294
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
295
+ else:
296
+ self.intermediate_act_fn = config.hidden_act
297
+
298
+ def forward(self, hidden_states):
299
+ hidden_states = self.dense(hidden_states)
300
+ hidden_states = self.intermediate_act_fn(hidden_states)
301
+ return hidden_states
302
+
303
+
304
+ class BertOutput(nn.Module):
305
+ def __init__(self, config):
306
+ super(BertOutput, self).__init__()
307
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
308
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
309
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
310
+
311
+ def forward(self, hidden_states, input_tensor):
312
+ hidden_states = self.dense(hidden_states)
313
+ hidden_states = self.dropout(hidden_states)
314
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
315
+ return hidden_states
316
+
317
+
318
+ class BertLayer(nn.Module):
319
+ def __init__(self, config):
320
+ super(BertLayer, self).__init__()
321
+ self.attention = BertAttention(config)
322
+ self.intermediate = BertIntermediate(config)
323
+ self.output = BertOutput(config)
324
+
325
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
326
+ attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
327
+ attention_output = attention_outputs[0]
328
+ intermediate_output = self.intermediate(attention_output)
329
+ layer_output = self.output(intermediate_output, attention_output)
330
+ outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
331
+ return outputs
332
+
333
+
334
+ class BertEncoder(nn.Module):
335
+ def __init__(self, config):
336
+ super(BertEncoder, self).__init__()
337
+ self.output_attentions = config.output_attentions
338
+ self.output_hidden_states = config.output_hidden_states
339
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
340
+
341
+ def forward(self, hidden_states, attention_mask=None, head_mask=None):
342
+ all_hidden_states = ()
343
+ all_attentions = ()
344
+ for i, layer_module in enumerate(self.layer):
345
+ if self.output_hidden_states:
346
+ all_hidden_states = all_hidden_states + (hidden_states,)
347
+
348
+ layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
349
+ hidden_states = layer_outputs[0]
350
+
351
+ if self.output_attentions:
352
+ all_attentions = all_attentions + (layer_outputs[1],)
353
+
354
+ # Add last layer
355
+ if self.output_hidden_states:
356
+ all_hidden_states = all_hidden_states + (hidden_states,)
357
+
358
+ outputs = (hidden_states,)
359
+ if self.output_hidden_states:
360
+ outputs = outputs + (all_hidden_states,)
361
+ if self.output_attentions:
362
+ outputs = outputs + (all_attentions,)
363
+ return outputs # last-layer hidden state, (all hidden states), (all attentions)
364
+
365
+
366
+ class BertPooler(nn.Module):
367
+ def __init__(self, config):
368
+ super(BertPooler, self).__init__()
369
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
370
+ self.activation = nn.Tanh()
371
+
372
+ def forward(self, hidden_states):
373
+ # We "pool" the model by simply taking the hidden state corresponding
374
+ # to the first token.
375
+ first_token_tensor = hidden_states[:, 0]
376
+ pooled_output = self.dense(first_token_tensor)
377
+ pooled_output = self.activation(pooled_output)
378
+ return pooled_output
379
+
380
+
381
+ class BertPredictionHeadTransform(nn.Module):
382
+ def __init__(self, config):
383
+ super(BertPredictionHeadTransform, self).__init__()
384
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
385
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
386
+ self.transform_act_fn = ACT2FN[config.hidden_act]
387
+ else:
388
+ self.transform_act_fn = config.hidden_act
389
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
390
+
391
+ def forward(self, hidden_states):
392
+ hidden_states = self.dense(hidden_states)
393
+ hidden_states = self.transform_act_fn(hidden_states)
394
+ hidden_states = self.LayerNorm(hidden_states)
395
+ return hidden_states
396
+
397
+
398
+ class BertLMPredictionHead(nn.Module):
399
+ def __init__(self, config):
400
+ super(BertLMPredictionHead, self).__init__()
401
+ self.transform = BertPredictionHeadTransform(config)
402
+
403
+ # The output weights are the same as the input embeddings, but there is
404
+ # an output-only bias for each token.
405
+ self.decoder = nn.Linear(config.hidden_size,
406
+ config.vocab_size,
407
+ bias=False)
408
+
409
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
410
+
411
+ def forward(self, hidden_states):
412
+ hidden_states = self.transform(hidden_states)
413
+ hidden_states = self.decoder(hidden_states) + self.bias
414
+ return hidden_states
415
+
416
+
417
+ class BertOnlyMLMHead(nn.Module):
418
+ def __init__(self, config):
419
+ super(BertOnlyMLMHead, self).__init__()
420
+ self.predictions = BertLMPredictionHead(config)
421
+
422
+ def forward(self, sequence_output):
423
+ prediction_scores = self.predictions(sequence_output)
424
+ return prediction_scores
425
+
426
+
427
+ class BertOnlyNSPHead(nn.Module):
428
+ def __init__(self, config):
429
+ super(BertOnlyNSPHead, self).__init__()
430
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
431
+
432
+ def forward(self, pooled_output):
433
+ seq_relationship_score = self.seq_relationship(pooled_output)
434
+ return seq_relationship_score
435
+
436
+
437
+ class BertPreTrainingHeads(nn.Module):
438
+ def __init__(self, config):
439
+ super(BertPreTrainingHeads, self).__init__()
440
+ self.predictions = BertLMPredictionHead(config)
441
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
442
+
443
+ def forward(self, sequence_output, pooled_output):
444
+ prediction_scores = self.predictions(sequence_output)
445
+ seq_relationship_score = self.seq_relationship(pooled_output)
446
+ return prediction_scores, seq_relationship_score
447
+
448
+
449
+ class BertPreTrainedModel(PreTrainedModel):
450
+ """ An abstract class to handle weights initialization and
451
+ a simple interface for downloading and loading pretrained models.
452
+ """
453
+ config_class = BertConfig
454
+ pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
455
+ load_tf_weights = load_tf_weights_in_bert
456
+ base_model_prefix = "bert"
457
+
458
+ def _init_weights(self, module):
459
+ """ Initialize the weights """
460
+ if isinstance(module, (nn.Linear, nn.Embedding)):
461
+ # Slightly different from the TF version which uses truncated_normal for initialization
462
+ # cf https://github.com/pytorch/pytorch/pull/5617
463
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
464
+ elif isinstance(module, BertLayerNorm):
465
+ module.bias.data.zero_()
466
+ module.weight.data.fill_(1.0)
467
+ if isinstance(module, nn.Linear) and module.bias is not None:
468
+ module.bias.data.zero_()
469
+
470
+
471
+ BERT_START_DOCSTRING = r""" The BERT model was proposed in
472
+ `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
473
+ by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
474
+ pre-trained using a combination of masked language modeling objective and next sentence prediction
475
+ on a large corpus comprising the Toronto Book Corpus and Wikipedia.
476
+
477
+ This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
478
+ refer to the PyTorch documentation for all matter related to general usage and behavior.
479
+
480
+ .. _`BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
481
+ https://arxiv.org/abs/1810.04805
482
+
483
+ .. _`torch.nn.Module`:
484
+ https://pytorch.org/docs/stable/nn.html#module
485
+
486
+ Parameters:
487
+ config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
488
+ Initializing with a config file does not load the weights associated with the model, only the configuration.
489
+ Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
490
+ """
491
+
492
+ BERT_INPUTS_DOCSTRING = r"""
493
+ Inputs:
494
+ **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
495
+ Indices of input sequence tokens in the vocabulary.
496
+ To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:
497
+
498
+ (a) For sequence pairs:
499
+
500
+ ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
501
+
502
+ ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
503
+
504
+ (b) For single sequences:
505
+
506
+ ``tokens: [CLS] the dog is hairy . [SEP]``
507
+
508
+ ``token_type_ids: 0 0 0 0 0 0 0``
509
+
510
+ Bert is a model with absolute position embeddings so it's usually advised to pad the inputs on
511
+ the right rather than the left.
512
+
513
+ Indices can be obtained using :class:`transformers.BertTokenizer`.
514
+ See :func:`transformers.PreTrainedTokenizer.encode` and
515
+ :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
516
+ **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
517
+ Mask to avoid performing attention on padding token indices.
518
+ Mask values selected in ``[0, 1]``:
519
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
520
+ **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
521
+ Segment token indices to indicate first and second portions of the inputs.
522
+ Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
523
+ corresponds to a `sentence B` token
524
+ (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
525
+ **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
526
+ Indices of positions of each input sequence tokens in the position embeddings.
527
+ Selected in the range ``[0, config.max_position_embeddings - 1]``.
528
+ **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
529
+ Mask to nullify selected heads of the self-attention modules.
530
+ Mask values selected in ``[0, 1]``:
531
+ ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
532
+ """
533
+
534
+ @add_start_docstrings("The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
535
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
536
+ class BertModel(BertPreTrainedModel):
537
+ r"""
538
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
539
+ **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
540
+ Sequence of hidden-states at the output of the last layer of the model.
541
+ **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
542
+ Last layer hidden-state of the first token of the sequence (classification token)
543
+ further processed by a Linear layer and a Tanh activation function. The Linear
544
+ layer weights are trained from the next sentence prediction (classification)
545
+ objective during Bert pretraining. This output is usually *not* a good summary
546
+ of the semantic content of the input; you're often better off averaging or pooling
547
+ the sequence of hidden-states for the whole input sequence.
548
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
549
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
550
+ of shape ``(batch_size, sequence_length, hidden_size)``:
551
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
552
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
553
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
554
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
555
+
556
+ Examples::
557
+
558
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
559
+ model = BertModel.from_pretrained('bert-base-uncased')
560
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
561
+ outputs = model(input_ids)
562
+ last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
563
+
564
+ """
565
+ def __init__(self, config):
566
+ super(BertModel, self).__init__(config)
567
+
568
+ self.embeddings = BertEmbeddings(config)
569
+ self.encoder = BertEncoder(config)
570
+ self.pooler = BertPooler(config)
571
+
572
+ self.init_weights()
573
+
574
+ def _resize_token_embeddings(self, new_num_tokens):
575
+ old_embeddings = self.embeddings.word_embeddings
576
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
577
+ self.embeddings.word_embeddings = new_embeddings
578
+ return self.embeddings.word_embeddings
579
+
580
+ def _prune_heads(self, heads_to_prune):
581
+ """ Prunes heads of the model.
582
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
583
+ See base class PreTrainedModel
584
+ """
585
+ for layer, heads in heads_to_prune.items():
586
+ self.encoder.layer[layer].attention.prune_heads(heads)
587
+
588
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
589
+ if attention_mask is None:
590
+ attention_mask = torch.ones_like(input_ids)
591
+ if token_type_ids is None:
592
+ token_type_ids = torch.zeros_like(input_ids)
593
+
594
+ # We create a 3D attention mask from a 2D tensor mask.
595
+ # Sizes are [batch_size, 1, 1, to_seq_length]
596
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
597
+ # this attention mask is more simple than the triangular masking of causal attention
598
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
599
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
600
+
601
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
602
+ # masked positions, this operation will create a tensor which is 0.0 for
603
+ # positions we want to attend and -10000.0 for masked positions.
604
+ # Since we are adding it to the raw scores before the softmax, this is
605
+ # effectively the same as removing these entirely.
606
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
607
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
608
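+ # For example, a 2D mask row [1, 1, 0] becomes [0.0, 0.0, -10000.0] after the two lines above,
+ # so the padded position receives ~zero attention probability after the softmax.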
+
609
+ # Prepare head mask if needed
610
+ # 1.0 in head_mask indicate we keep the head
611
+ # attention_probs has shape bsz x n_heads x N x N
612
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
613
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
614
+ if head_mask is not None:
615
+ if head_mask.dim() == 1:
616
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
617
+ head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
618
+ elif head_mask.dim() == 2:
619
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
620
+ head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
621
+ else:
622
+ head_mask = [None] * self.config.num_hidden_layers
623
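+ # For example, passing head_mask = torch.ones(num_hidden_layers, num_attention_heads) keeps every head,
+ # while zeroing an entry silences the corresponding head in that layer.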
+
624
+ embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
625
+ encoder_outputs = self.encoder(embedding_output,
626
+ extended_attention_mask,
627
+ head_mask=head_mask)
628
+ sequence_output = encoder_outputs[0]
629
+ pooled_output = self.pooler(sequence_output)
630
+
631
+ outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
632
+ return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
633
+
634
+
635
+ @add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
636
+ a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
637
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
638
+ class BertForPreTraining(BertPreTrainedModel):
639
+ r"""
640
+ **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
641
+ Labels for computing the masked language modeling loss.
642
+ Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
643
+ Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
644
+ in ``[0, ..., config.vocab_size]``
645
+ **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
646
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
647
+ Indices should be in ``[0, 1]``.
648
+ ``0`` indicates sequence B is a continuation of sequence A,
649
+ ``1`` indicates sequence B is a random sequence.
650
+
651
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
652
+ **loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
653
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
654
+ **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
655
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
656
+ **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
657
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
658
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
659
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
660
+ of shape ``(batch_size, sequence_length, hidden_size)``:
661
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
662
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
663
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
664
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
665
+
666
+ Examples::
667
+
668
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
669
+ model = BertForPreTraining.from_pretrained('bert-base-uncased')
670
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
671
+ outputs = model(input_ids)
672
+ prediction_scores, seq_relationship_scores = outputs[:2]
673
+
674
+ """
675
+ def __init__(self, config):
676
+ super(BertForPreTraining, self).__init__(config)
677
+
678
+ self.bert = BertModel(config)
679
+ self.cls = BertPreTrainingHeads(config)
680
+
681
+ self.init_weights()
682
+ self.tie_weights()
683
+
684
+ def tie_weights(self):
685
+ """ Make sure we are sharing the input and output embeddings.
686
+ Export to TorchScript can't handle parameter sharing so we are cloning them instead.
687
+ """
688
+ self._tie_or_clone_weights(self.cls.predictions.decoder,
689
+ self.bert.embeddings.word_embeddings)
690
+
691
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
692
+ masked_lm_labels=None, next_sentence_label=None):
693
+
694
+ outputs = self.bert(input_ids,
695
+ attention_mask=attention_mask,
696
+ token_type_ids=token_type_ids,
697
+ position_ids=position_ids,
698
+ head_mask=head_mask)
699
+
700
+ sequence_output, pooled_output = outputs[:2]
701
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
702
+
703
+ outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
704
+
705
+ if masked_lm_labels is not None and next_sentence_label is not None:
706
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
707
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
708
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
709
+ total_loss = masked_lm_loss + next_sentence_loss
710
+ outputs = (total_loss,) + outputs
711
+
712
+ return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
713
+
714
+
715
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top. """,
716
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
717
+ class BertForMaskedLM(BertPreTrainedModel):
718
+ r"""
719
+ **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
720
+ Labels for computing the masked language modeling loss.
721
+ Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
722
+ Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
723
+ in ``[0, ..., config.vocab_size]``
724
+
725
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
726
+ **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
727
+ Masked language modeling loss.
728
+ **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
729
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
730
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
731
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
732
+ of shape ``(batch_size, sequence_length, hidden_size)``:
733
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
734
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
735
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
736
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
737
+
738
+ Examples::
739
+
740
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
741
+ model = BertForMaskedLM.from_pretrained('bert-base-uncased')
742
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
743
+ outputs = model(input_ids, masked_lm_labels=input_ids)
744
+ loss, prediction_scores = outputs[:2]
745
+
746
+ """
747
+ def __init__(self, config):
748
+ super(BertForMaskedLM, self).__init__(config)
749
+
750
+ self.bert = BertModel(config)
751
+ self.cls = BertOnlyMLMHead(config)
752
+
753
+ self.init_weights()
754
+ self.tie_weights()
755
+
756
+ def tie_weights(self):
757
+ """ Make sure we are sharing the input and output embeddings.
758
+ Export to TorchScript can't handle parameter sharing so we are cloning them instead.
759
+ """
760
+ self._tie_or_clone_weights(self.cls.predictions.decoder,
761
+ self.bert.embeddings.word_embeddings)
762
+
763
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
764
+ masked_lm_labels=None):
765
+
766
+ outputs = self.bert(input_ids,
767
+ attention_mask=attention_mask,
768
+ token_type_ids=token_type_ids,
769
+ position_ids=position_ids,
770
+ head_mask=head_mask)
771
+
772
+ sequence_output = outputs[0]
773
+ prediction_scores = self.cls(sequence_output)
774
+
775
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
776
+ if masked_lm_labels is not None:
777
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
778
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
779
+ outputs = (masked_lm_loss,) + outputs
780
+
781
+ return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
782
+
783
+
784
+ @add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
785
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
786
+ class BertForNextSentencePrediction(BertPreTrainedModel):
787
+ r"""
788
+ **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
789
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
790
+ Indices should be in ``[0, 1]``.
791
+ ``0`` indicates sequence B is a continuation of sequence A,
792
+ ``1`` indicates sequence B is a random sequence.
793
+
794
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
795
+ **loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
796
+ Next sequence prediction (classification) loss.
797
+ **seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, 2)``
798
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
799
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
800
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
801
+ of shape ``(batch_size, sequence_length, hidden_size)``:
802
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
803
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
804
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
805
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
806
+
807
+ Examples::
808
+
809
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
810
+ model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
811
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
812
+ outputs = model(input_ids)
813
+ seq_relationship_scores = outputs[0]
814
+
815
+ """
816
+ def __init__(self, config):
817
+ super(BertForNextSentencePrediction, self).__init__(config)
818
+
819
+ self.bert = BertModel(config)
820
+ self.cls = BertOnlyNSPHead(config)
821
+
822
+ self.init_weights()
823
+
824
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
825
+ next_sentence_label=None):
826
+
827
+ outputs = self.bert(input_ids,
828
+ attention_mask=attention_mask,
829
+ token_type_ids=token_type_ids,
830
+ position_ids=position_ids,
831
+ head_mask=head_mask)
832
+
833
+ pooled_output = outputs[1]
834
+
835
+ seq_relationship_score = self.cls(pooled_output)
836
+
837
+ outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
838
+ if next_sentence_label is not None:
839
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
840
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
841
+ outputs = (next_sentence_loss,) + outputs
842
+
843
+ return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
844
+
845
+
846
+ @add_start_docstrings("""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
847
+ the pooled output) e.g. for GLUE tasks. """,
848
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
849
+ class BertForSequenceClassification(BertPreTrainedModel):
850
+ r"""
851
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
852
+ Labels for computing the sequence classification/regression loss.
853
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
854
+ If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
855
+ If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
856
+
857
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
858
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
859
+ Classification (or regression if config.num_labels==1) loss.
860
+ **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
861
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
862
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
863
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
864
+ of shape ``(batch_size, sequence_length, hidden_size)``:
865
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
866
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
867
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
868
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
869
+
870
+ Examples::
871
+
872
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
873
+ model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
874
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
875
+ labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
876
+ outputs = model(input_ids, labels=labels)
877
+ loss, logits = outputs[:2]
878
+
879
+ """
880
+ def __init__(self, config):
881
+ super(BertForSequenceClassification, self).__init__(config)
882
+ self.num_labels = config.num_labels
883
+
884
+ self.bert = BertModel(config)
885
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
886
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
887
+
888
+ self.init_weights()
889
+
890
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None,
891
+ position_ids=None, head_mask=None, labels=None):
892
+
893
+ outputs = self.bert(input_ids,
894
+ attention_mask=attention_mask,
895
+ token_type_ids=token_type_ids,
896
+ position_ids=position_ids,
897
+ head_mask=head_mask)
898
+
899
+ pooled_output = outputs[1]
900
+
901
+ pooled_output = self.dropout(pooled_output)
902
+ logits = self.classifier(pooled_output)
903
+
904
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
905
+
906
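+ # Note: with num_labels == 1 the labels are expected to be float-valued targets for MSELoss;
+ # otherwise they are class indices in [0, num_labels - 1] for CrossEntropyLoss.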
+ if labels is not None:
907
+ if self.num_labels == 1:
908
+ # We are doing regression
909
+ loss_fct = MSELoss()
910
+ loss = loss_fct(logits.view(-1), labels.view(-1))
911
+ else:
912
+ loss_fct = CrossEntropyLoss()
913
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
914
+ outputs = (loss,) + outputs
915
+
916
+ return outputs # (loss), logits, (hidden_states), (attentions)
917
+
918
+
919
+ @add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
920
+ the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
921
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
922
+ class BertForMultipleChoice(BertPreTrainedModel):
923
+ r"""
924
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
925
+ Labels for computing the multiple choice classification loss.
926
+ Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
927
+ of the input tensors. (see `input_ids` above)
928
+
929
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
930
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
931
+ Classification loss.
932
+ **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
933
+ of the input tensors. (see `input_ids` above).
934
+ Classification scores (before SoftMax).
935
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
936
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
937
+ of shape ``(batch_size, sequence_length, hidden_size)``:
938
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
939
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
940
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
941
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
942
+
943
+ Examples::
944
+
945
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
946
+ model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
947
+ choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
948
+ input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
949
+ labels = torch.tensor(1).unsqueeze(0) # Batch size 1
950
+ outputs = model(input_ids, labels=labels)
951
+ loss, classification_scores = outputs[:2]
952
+
953
+ """
954
+ def __init__(self, config):
955
+ super(BertForMultipleChoice, self).__init__(config)
956
+
957
+ self.bert = BertModel(config)
958
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
959
+ self.classifier = nn.Linear(config.hidden_size, 1)
960
+
961
+ self.init_weights()
962
+
963
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None,
964
+ position_ids=None, head_mask=None, labels=None):
965
+ num_choices = input_ids.shape[1]
966
+
967
+ input_ids = input_ids.view(-1, input_ids.size(-1))
968
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
969
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
970
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
971
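+ # Shape walk-through: input_ids arrives as (batch_size, num_choices, seq_length) and is flattened to
+ # (batch_size * num_choices, seq_length); the per-choice logits are reshaped back to (batch_size, num_choices) below.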
+
972
+ outputs = self.bert(input_ids,
973
+ attention_mask=attention_mask,
974
+ token_type_ids=token_type_ids,
975
+ position_ids=position_ids,
976
+ head_mask=head_mask)
977
+
978
+ pooled_output = outputs[1]
979
+
980
+ pooled_output = self.dropout(pooled_output)
981
+ logits = self.classifier(pooled_output)
982
+ reshaped_logits = logits.view(-1, num_choices)
983
+
984
+ outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
985
+
986
+ if labels is not None:
987
+ loss_fct = CrossEntropyLoss()
988
+ loss = loss_fct(reshaped_logits, labels)
989
+ outputs = (loss,) + outputs
990
+
991
+ return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
992
+
993
+
994
+ @add_start_docstrings("""Bert Model with a token classification head on top (a linear layer on top of
995
+ the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
996
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
997
+ class BertForTokenClassification(BertPreTrainedModel):
998
+ r"""
999
+ **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
1000
+ Labels for computing the token classification loss.
1001
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
1002
+
1003
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
1004
+ **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
1005
+ Classification loss.
1006
+ **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
1007
+ Classification scores (before SoftMax).
1008
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
1009
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
1010
+ of shape ``(batch_size, sequence_length, hidden_size)``:
1011
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1012
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
1013
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
1014
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
1015
+
1016
+ Examples::
1017
+
1018
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1019
+ model = BertForTokenClassification.from_pretrained('bert-base-uncased')
1020
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
1021
+ labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
1022
+ outputs = model(input_ids, labels=labels)
1023
+ loss, scores = outputs[:2]
1024
+
1025
+ """
1026
+ def __init__(self, config):
1027
+ super(BertForTokenClassification, self).__init__(config)
1028
+ self.num_labels = config.num_labels
1029
+
1030
+ self.bert = BertModel(config)
1031
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1032
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1033
+
1034
+ self.init_weights()
1035
+
1036
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None,
1037
+ position_ids=None, head_mask=None, labels=None):
1038
+
1039
+ outputs = self.bert(input_ids,
1040
+ attention_mask=attention_mask,
1041
+ token_type_ids=token_type_ids,
1042
+ position_ids=position_ids,
1043
+ head_mask=head_mask)
1044
+
1045
+ sequence_output = outputs[0]
1046
+
1047
+ sequence_output = self.dropout(sequence_output)
1048
+ logits = self.classifier(sequence_output)
1049
+
1050
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
1051
+ if labels is not None:
1052
+ loss_fct = CrossEntropyLoss()
1053
+ # Only keep active parts of the loss
1054
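+ # i.e. positions where attention_mask == 1 (real tokens); padding positions are excluded from the loss.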
+ if attention_mask is not None:
1055
+ active_loss = attention_mask.view(-1) == 1
1056
+ active_logits = logits.view(-1, self.num_labels)[active_loss]
1057
+ active_labels = labels.view(-1)[active_loss]
1058
+ loss = loss_fct(active_logits, active_labels)
1059
+ else:
1060
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1061
+ outputs = (loss,) + outputs
1062
+
1063
+ return outputs # (loss), scores, (hidden_states), (attentions)
1064
+
1065
+
1066
+ @add_start_docstrings("""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
1067
+ the hidden-states output to compute `span start logits` and `span end logits`). """,
1068
+ BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
1069
+ class BertForQuestionAnswering(BertPreTrainedModel):
1070
+ r"""
1071
+ **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
1072
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1073
+ Positions are clamped to the length of the sequence (`sequence_length`).
1074
+ Positions outside of the sequence are not taken into account for computing the loss.
1075
+ **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
1076
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1077
+ Positions are clamped to the length of the sequence (`sequence_length`).
1078
+ Positions outside of the sequence are not taken into account for computing the loss.
1079
+
1080
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
1081
+ **loss**: (`optional`, returned when ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
1082
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
1083
+ **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
1084
+ Span-start scores (before SoftMax).
1085
+ **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
1086
+ Span-end scores (before SoftMax).
1087
+ **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
1088
+ list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
1089
+ of shape ``(batch_size, sequence_length, hidden_size)``:
1090
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1091
+ **attentions**: (`optional`, returned when ``config.output_attentions=True``)
1092
+ list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
1093
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
1094
+
1095
+ Examples::
1096
+
1097
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1098
+ model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
1099
+ input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
1100
+ start_positions = torch.tensor([1])
1101
+ end_positions = torch.tensor([3])
1102
+ outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
1103
+ loss, start_scores, end_scores = outputs[:3]
1104
+
1105
+ """
1106
+ def __init__(self, config):
1107
+ super(BertForQuestionAnswering, self).__init__(config)
1108
+ self.num_labels = config.num_labels
1109
+
1110
+ self.bert = BertModel(config)
1111
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1112
+
1113
+ self.init_weights()
1114
+
1115
+ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
1116
+ start_positions=None, end_positions=None):
1117
+
1118
+ outputs = self.bert(input_ids,
1119
+ attention_mask=attention_mask,
1120
+ token_type_ids=token_type_ids,
1121
+ position_ids=position_ids,
1122
+ head_mask=head_mask)
1123
+
1124
+ sequence_output = outputs[0]
1125
+
1126
+ logits = self.qa_outputs(sequence_output)
1127
+ start_logits, end_logits = logits.split(1, dim=-1)
1128
+ start_logits = start_logits.squeeze(-1)
1129
+ end_logits = end_logits.squeeze(-1)
1130
+
1131
+ outputs = (start_logits, end_logits,) + outputs[2:]
1132
+ if start_positions is not None and end_positions is not None:
1133
+ # If we are on multi-GPU, the batch split can add an extra dimension; squeeze it away
1134
+ if len(start_positions.size()) > 1:
1135
+ start_positions = start_positions.squeeze(-1)
1136
+ if len(end_positions.size()) > 1:
1137
+ end_positions = end_positions.squeeze(-1)
1138
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
1139
+ ignored_index = start_logits.size(1)
1140
+ start_positions.clamp_(0, ignored_index)
1141
+ end_positions.clamp_(0, ignored_index)
1142
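+ # A clamped position equals sequence_length, which matches ignore_index below,
+ # so out-of-range start/end labels simply do not contribute to the loss.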
+
1143
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1144
+ start_loss = loss_fct(start_logits, start_positions)
1145
+ end_loss = loss_fct(end_logits, end_positions)
1146
+ total_loss = (start_loss + end_loss) / 2
1147
+ outputs = (total_loss,) + outputs
1148
+
1149
+ return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
model/modeling_utils.py ADDED
@@ -0,0 +1,756 @@
1
+ """PyTorch model utilities: base classes shared by the BERT/ALBERT implementations in this repo."""
2
+
3
+ from __future__ import (absolute_import, division, print_function,
4
+ unicode_literals)
5
+
6
+ import logging
7
+ import os
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import CrossEntropyLoss
12
+ from torch.nn import functional as F
13
+ from model.configuration_utils import PretrainedConfig
14
+ from model.file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME
15
+ # from model.configuration_utils import PretrainedConfig
16
+ # from model.file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ try:
22
+ from torch.nn import Identity
23
+ except ImportError:
24
+ # Older PyTorch compatibility
25
+ class Identity(nn.Module):
26
+ r"""A placeholder identity operator that is argument-insensitive.
27
+ """
28
+ def __init__(self, *args, **kwargs):
29
+ super(Identity, self).__init__()
30
+
31
+ def forward(self, input):
32
+ return input
33
+
34
+ class PreTrainedModel(nn.Module):
35
+ r""" Base class for all models.
36
+
37
+ :class:`~pytorch_transformers.PreTrainedModel` takes care of storing the configuration of the models and provides methods for loading, downloading and saving models,
38
+ as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention modules.
39
+
40
+ Class attributes (overridden by derived classes):
41
+ - ``config_class``: a class derived from :class:`~pytorch_transformers.PretrainedConfig` to use as configuration class for this model architecture.
42
+ - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and the `url` (string) of the associated pretrained weights as values.
43
+ - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
44
+
45
+ - ``model``: an instance of the relevant subclass of :class:`~pytorch_transformers.PreTrainedModel`,
46
+ - ``config``: an instance of the relevant subclass of :class:`~pytorch_transformers.PretrainedConfig`,
47
+ - ``path``: a path (string) to the TensorFlow checkpoint.
48
+
49
+ - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
50
+ """
51
+ config_class = None
52
+ pretrained_model_archive_map = {}
53
+ load_tf_weights = lambda model, config, path: None
54
+ base_model_prefix = ""
55
+
56
+ def __init__(self, config, *inputs, **kwargs):
57
+ super(PreTrainedModel, self).__init__()
58
+ if not isinstance(config, PretrainedConfig):
59
+ raise ValueError(
60
+ "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
61
+ "To create a model from a pretrained model use "
62
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
63
+ self.__class__.__name__, self.__class__.__name__
64
+ ))
65
+ # Save config in model
66
+ self.config = config
67
+
68
+ def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
69
+ """ Build a resized Embedding Module from a provided token Embedding Module.
70
+ Increasing the size will add newly initialized vectors at the end
71
+ Reducing the size will remove vectors from the end
72
+
73
+ Args:
74
+ new_num_tokens: (`optional`) int
75
+ New number of tokens in the embedding matrix.
76
+ Increasing the size will add newly initialized vectors at the end
77
+ Reducing the size will remove vectors from the end
78
+ If not provided or None: return the provided token Embedding Module.
79
+ Return: ``torch.nn.Embeddings``
80
+ Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
81
+ """
82
+ if new_num_tokens is None:
83
+ return old_embeddings
84
+
85
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
86
+ if old_num_tokens == new_num_tokens:
87
+ return old_embeddings
88
+
89
+ # Build new embeddings
90
+ new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
91
+ new_embeddings.to(old_embeddings.weight.device)
92
+
93
+ # initialize all new embeddings (in particular added tokens)
94
+ self._init_weights(new_embeddings)
95
+
96
+ # Copy word embeddings from the previous weights
97
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
98
+ new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
99
+
100
+ return new_embeddings
101
+
102
+ def _tie_or_clone_weights(self, first_module, second_module):
103
+ """ Tie or clone module weights depending on whether we are using TorchScript or not
104
+ """
105
+
106
+ if self.config.torchscript:
107
+ first_module.weight = nn.Parameter(second_module.weight.clone())
108
+ else:
109
+ first_module.weight = second_module.weight
110
+
111
+
112
+ if hasattr(first_module, 'bias') and first_module.bias is not None:
113
+ first_module.bias.data = torch.nn.functional.pad(
114
+ first_module.bias.data,
115
+ (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
116
+ 'constant',
117
+ 0
118
+ )
119
+
120
+ def resize_token_embeddings(self, new_num_tokens=None):
121
+ """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
122
+ Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
123
+
124
+ Arguments:
125
+
126
+ new_num_tokens: (`optional`) int:
127
+ New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
128
+ If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
129
+
130
+ Return: ``torch.nn.Embeddings``
131
+ Pointer to the input tokens Embeddings Module of the model
132
+ """
133
+ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
134
+ model_embeds = base_model._resize_token_embeddings(new_num_tokens)
135
+ if new_num_tokens is None:
136
+ return model_embeds
137
+
138
+ # Update base model and current model config
139
+ self.config.vocab_size = new_num_tokens
140
+ base_model.vocab_size = new_num_tokens
141
+
142
+ # Tie weights again if needed
143
+ if hasattr(self, 'tie_weights'):
144
+ self.tie_weights()
145
+
146
+ return model_embeds
147
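+ # Minimal usage sketch (illustrative, assuming `model` and a tokenizer whose vocabulary grew):
+ #   model.resize_token_embeddings(new_num_tokens=len(tokenizer))  # copies existing rows, adds freshly initialized ones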
+
148
+ def init_weights(self):
149
+ """ Initialize weights and prune heads if needed. """
150
+ # Initialize weights
151
+ self.apply(self._init_weights)
152
+
153
+ # Prune heads if needed
154
+ if self.config.pruned_heads:
155
+ self.prune_heads(self.config.pruned_heads)
156
+
157
+ def prune_heads(self, heads_to_prune):
158
+ """ Prunes heads of the base model.
159
+
160
+ Arguments:
161
+
162
+ heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
163
+ E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
164
+ """
165
+ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
166
+
167
+ # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
168
+ for layer, heads in heads_to_prune.items():
169
+ union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
170
+ self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
171
+
172
+ base_model._prune_heads(heads_to_prune)
173
+
174
+ def save_pretrained(self, save_directory):
175
+ """ Save a model and its configuration file to a directory, so that it
176
+ can be re-loaded using the :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` class method.
177
+ """
178
+ assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"
179
+
180
+ # Only save the model it-self if we are using distributed training
181
+ model_to_save = self.module if hasattr(self, 'module') else self
182
+
183
+ # Save configuration file
184
+ model_to_save.config.save_pretrained(save_directory)
185
+
186
+ # If we save using the predefined names, we can load using `from_pretrained`
187
+ output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
188
+
189
+ torch.save(model_to_save.state_dict(), output_model_file)
190
+
191
+ @classmethod
192
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
193
+ r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
194
+
195
+ The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
196
+ To train the model, you should first set it back in training mode with ``model.train()``
197
+
198
+ The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
199
+ It is up to you to train those weights with a downstream fine-tuning task.
200
+
201
+ The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
202
+
203
+ Parameters:
204
+ pretrained_model_name_or_path: either:
205
+
206
+ - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
207
+ - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
208
+ - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
209
+
210
+ model_args: (`optional`) Sequence of positional arguments:
211
+ All remaning positional arguments will be passed to the underlying model's ``__init__`` method
212
+
213
+ config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`:
214
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
215
+
216
+ - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
217
+ - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
218
+ - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
219
+
220
+ state_dict: (`optional`) dict:
221
+ an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
222
+ This option can be used if you want to create a model from a pretrained configuration but load your own weights.
223
+ In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option.
224
+
225
+ cache_dir: (`optional`) string:
226
+ Path to a directory in which a downloaded pre-trained model
227
+ configuration should be cached if the standard cache should not be used.
228
+
229
+ force_download: (`optional`) boolean, default False:
230
+ Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
231
+
232
+ proxies: (`optional`) dict, default None:
233
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
234
+ The proxies are used on each request.
235
+
236
+ output_loading_info: (`optional`) boolean:
237
+ Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
238
+
239
+ kwargs: (`optional`) Remaining dictionary of keyword arguments:
240
+ Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a ``config`` is provided or automatically loaded:
241
+
242
+ - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
243
+ - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
244
+
245
+ Examples::
246
+
247
+ model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
248
+ model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
249
+ model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
250
+ assert model.config.output_attention == True
251
+ # Loading from a TF checkpoint file instead of a PyTorch model (slower)
252
+ config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
253
+ model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
254
+
255
+ """
256
+ config = kwargs.pop('config', None)
257
+ state_dict = kwargs.pop('state_dict', None)
258
+ cache_dir = kwargs.pop('cache_dir', None)
259
+ from_tf = kwargs.pop('from_tf', False)
260
+ force_download = kwargs.pop('force_download', False)
261
+ proxies = kwargs.pop('proxies', None)
262
+ output_loading_info = kwargs.pop('output_loading_info', False)
263
+
264
+ # Load config
265
+ if config is None:
266
+ config, model_kwargs = cls.config_class.from_pretrained(
267
+ pretrained_model_name_or_path, *model_args,
268
+ cache_dir=cache_dir, return_unused_kwargs=True,
269
+ force_download=force_download,
270
+ **kwargs
271
+ )
272
+ else:
273
+ model_kwargs = kwargs
274
+
275
+ # Load model
276
+ if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
277
+ archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
278
+ elif os.path.isdir(pretrained_model_name_or_path):
279
+ if from_tf:
280
+ # Directly load from a TensorFlow checkpoint
281
+ archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
282
+ else:
283
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
284
+ else:
285
+ if from_tf:
286
+ # Directly load from a TensorFlow checkpoint
287
+ archive_file = pretrained_model_name_or_path + ".index"
288
+ else:
289
+ archive_file = pretrained_model_name_or_path
290
+ # redirect to the cache, if necessary
291
+ try:
292
+ resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
293
+ except EnvironmentError as e:
294
+ if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
295
+ logger.error(
296
+ "Couldn't reach server at '{}' to download pretrained weights.".format(
297
+ archive_file))
298
+ else:
299
+ logger.error(
300
+ "Model name '{}' was not found in model name list ({}). "
301
+ "We assumed '{}' was a path or url but couldn't find any file "
302
+ "associated to this path or url.".format(
303
+ pretrained_model_name_or_path,
304
+ ', '.join(cls.pretrained_model_archive_map.keys()),
305
+ archive_file))
306
+ raise e
307
+ if resolved_archive_file == archive_file:
308
+ logger.info("loading weights file {}".format(archive_file))
309
+ else:
310
+ logger.info("loading weights file {} from cache at {}".format(
311
+ archive_file, resolved_archive_file))
312
+
313
+ # Instantiate model.
314
+ model = cls(config, *model_args, **model_kwargs)
315
+
316
+ if state_dict is None and not from_tf:
317
+ state_dict = torch.load(resolved_archive_file, map_location='cpu')
318
+ if from_tf:
319
+ # Directly load from a TensorFlow checkpoint
320
+ return cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
321
+
322
+ # Convert old format to new format if needed from a PyTorch state_dict
323
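+ # e.g. an old-style key '...LayerNorm.gamma' is renamed to '...LayerNorm.weight'
+ # and '...LayerNorm.beta' to '...LayerNorm.bias'.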
+ old_keys = []
324
+ new_keys = []
325
+ for key in state_dict.keys():
326
+ new_key = None
327
+ if 'gamma' in key:
328
+ new_key = key.replace('gamma', 'weight')
329
+ if 'beta' in key:
330
+ new_key = key.replace('beta', 'bias')
331
+ if new_key:
332
+ old_keys.append(key)
333
+ new_keys.append(new_key)
334
+ for old_key, new_key in zip(old_keys, new_keys):
335
+ state_dict[new_key] = state_dict.pop(old_key)
336
+
337
+ # Load from a PyTorch state_dict
338
+ missing_keys = []
339
+ unexpected_keys = []
340
+ error_msgs = []
341
+ # copy state_dict so _load_from_state_dict can modify it
342
+ metadata = getattr(state_dict, '_metadata', None)
343
+ state_dict = state_dict.copy()
344
+ if metadata is not None:
345
+ state_dict._metadata = metadata
346
+
347
+ def load(module, prefix=''):
348
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
349
+ module._load_from_state_dict(
350
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
351
+ for name, child in module._modules.items():
352
+ if child is not None:
353
+ load(child, prefix + name + '.')
354
+
355
+ # Make sure we are able to load base models as well as derived models (with heads)
356
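+ # e.g. weights saved from a bare BertModel (keys without the 'bert.' prefix) are routed into
+ # model.bert of a derived model, and a checkpoint saved with heads can initialize a bare BertModel
+ # by loading it under the 'bert.' prefix.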
+ start_prefix = ''
357
+ model_to_load = model
358
+ if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
359
+ start_prefix = cls.base_model_prefix + '.'
360
+ if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
361
+ model_to_load = getattr(model, cls.base_model_prefix)
362
+
363
+ load(model_to_load, prefix=start_prefix)
364
+ if len(missing_keys) > 0:
365
+ logger.info("Weights of {} not initialized from pretrained model: {}".format(
366
+ model.__class__.__name__, missing_keys))
367
+ if len(unexpected_keys) > 0:
368
+ logger.info("Weights from pretrained model not used in {}: {}".format(
369
+ model.__class__.__name__, unexpected_keys))
370
+ if len(error_msgs) > 0:
371
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
372
+ model.__class__.__name__, "\n\t".join(error_msgs)))
373
+
374
+ if hasattr(model, 'tie_weights'):
375
+ model.tie_weights() # make sure word embedding weights are still tied
376
+
377
+ # Set model in evaluation mode to desactivate DropOut modules by default
378
+ model.eval()
379
+
380
+ if output_loading_info:
381
+ loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
382
+ return model, loading_info
383
+
384
+ return model
385
+
386
+
387
+ class Conv1D(nn.Module):
388
+ def __init__(self, nf, nx):
389
+ """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
390
+ Basically works like a Linear layer but the weights are transposed
391
+ """
392
+ super(Conv1D, self).__init__()
393
+ self.nf = nf
394
+ w = torch.empty(nx, nf)
395
+ nn.init.normal_(w, std=0.02)
396
+ self.weight = nn.Parameter(w)
397
+ self.bias = nn.Parameter(torch.zeros(nf))
398
+
399
+ def forward(self, x):
400
+ size_out = x.size()[:-1] + (self.nf,)
401
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
402
+ x = x.view(*size_out)
403
+ return x
404
+
405
+
406
+ class PoolerStartLogits(nn.Module):
407
+ """ Compute SQuAD start_logits from sequence hidden states. """
408
+ def __init__(self, config):
409
+ super(PoolerStartLogits, self).__init__()
410
+ self.dense = nn.Linear(config.hidden_size, 1)
411
+
412
+ def forward(self, hidden_states, p_mask=None):
413
+ """ Args:
414
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
415
+ invalid position mask such as query and special symbols (PAD, SEP, CLS)
416
+ 1.0 means token should be masked.
417
+ """
418
+ x = self.dense(hidden_states).squeeze(-1)
419
+
420
+ if p_mask is not None:
421
+ x = x * (1 - p_mask) - 1e30 * p_mask
422
+
423
+ return x
424
+
425
+
426
+ class PoolerEndLogits(nn.Module):
427
+ """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
428
+ """
429
+ def __init__(self, config):
430
+ super(PoolerEndLogits, self).__init__()
431
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
432
+ self.activation = nn.Tanh()
433
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
434
+ self.dense_1 = nn.Linear(config.hidden_size, 1)
435
+
436
+ def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
437
+ """ Args:
438
+ One of ``start_states``, ``start_positions`` should be not None.
439
+ If both are set, ``start_positions`` overrides ``start_states``.
440
+
441
+ **start_states**: ``torch.FloatTensor`` of shape identical to hidden_states
442
+ hidden states of the first tokens for the labeled span.
443
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
444
+ position of the first token for the labeled span.
445
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
446
+ Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
447
+ 1.0 means token should be masked.
448
+ """
449
+ assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
450
+ if start_positions is not None:
451
+ slen, hsz = hidden_states.shape[-2:]
452
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
453
+ start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
454
+ start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
455
+
456
+ x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
457
+ x = self.activation(x)
458
+ x = self.LayerNorm(x)
459
+ x = self.dense_1(x).squeeze(-1)
460
+
461
+ if p_mask is not None:
462
+ x = x * (1 - p_mask) - 1e30 * p_mask
463
+
464
+ return x
465
+
466
+
467
+ class PoolerAnswerClass(nn.Module):
468
+ """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
469
+ def __init__(self, config):
470
+ super(PoolerAnswerClass, self).__init__()
471
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
472
+ self.activation = nn.Tanh()
473
+ self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
474
+
475
+ def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
476
+ """
477
+ Args:
478
+ One of ``start_states``, ``start_positions`` should be not None.
479
+ If both are set, ``start_positions`` overrides ``start_states``.
480
+
481
+ **start_states**: ``torch.FloatTensor`` of shape identical to ``hidden_states``.
482
+ hidden states of the first tokens for the labeled span.
483
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
484
+ position of the first token for the labeled span.
485
+ **cls_index**: torch.LongTensor of shape ``(batch_size,)``
486
+ position of the CLS token. If None, take the last token.
487
+
488
+ note(Original repo):
489
+ no dependency on end_feature so that we can obtain one single `cls_logits`
490
+ for each sample
491
+ """
492
+ hsz = hidden_states.shape[-1]
493
+ assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
494
+ if start_positions is not None:
495
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
496
+ start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
497
+
498
+ if cls_index is not None:
499
+ cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
500
+ cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
501
+ else:
502
+ cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
503
+
504
+ x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
505
+ x = self.activation(x)
506
+ x = self.dense_1(x).squeeze(-1)
507
+
508
+ return x
509
+
510
+
511
+ class SQuADHead(nn.Module):
512
+ r""" A SQuAD head inspired by XLNet.
513
+
514
+ Parameters:
515
+ config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
516
+
517
+ Inputs:
518
+ **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
519
+ hidden states of sequence tokens
520
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
521
+ position of the first token for the labeled span.
522
+ **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
523
+ position of the last token for the labeled span.
524
+ **cls_index**: torch.LongTensor of shape ``(batch_size,)``
525
+ position of the CLS token. If None, take the last token.
526
+ **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
527
+ Whether the question has a possible answer in the paragraph or not.
528
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
529
+ Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
530
+ 1.0 means token should be masked.
531
+
532
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
533
+ **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
534
+ Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
535
+ **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
536
+ ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
537
+ Log probabilities for the top config.start_n_top start token possibilities (beam-search).
538
+ **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
539
+ ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
540
+ Indices for the top config.start_n_top start token possibilities (beam-search).
541
+ **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
542
+ ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
543
+ Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
544
+ **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
545
+ ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
546
+ Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
547
+ **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
548
+ ``torch.FloatTensor`` of shape ``(batch_size,)``
549
+ Log probabilities for the ``is_impossible`` label of the answers.
550
+ """
551
+ def __init__(self, config):
552
+ super(SQuADHead, self).__init__()
553
+ self.start_n_top = config.start_n_top
554
+ self.end_n_top = config.end_n_top
555
+
556
+ self.start_logits = PoolerStartLogits(config)
557
+ self.end_logits = PoolerEndLogits(config)
558
+ self.answer_class = PoolerAnswerClass(config)
559
+
560
+ def forward(self, hidden_states, start_positions=None, end_positions=None,
561
+ cls_index=None, is_impossible=None, p_mask=None):
562
+ outputs = ()
563
+
564
+ start_logits = self.start_logits(hidden_states, p_mask=p_mask)
565
+
566
+ if start_positions is not None and end_positions is not None:
567
+ # If we are on multi-GPU, let's remove the dimension added by batch splitting
568
+ for x in (start_positions, end_positions, cls_index, is_impossible):
569
+ if x is not None and x.dim() > 1:
570
+ x.squeeze_(-1)
571
+
572
+ # during training, compute the end logits based on the ground truth of the start position
573
+ end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
574
+
575
+ loss_fct = CrossEntropyLoss()
576
+ start_loss = loss_fct(start_logits, start_positions)
577
+ end_loss = loss_fct(end_logits, end_positions)
578
+ total_loss = (start_loss + end_loss) / 2
579
+
580
+ if cls_index is not None and is_impossible is not None:
581
+ # Predict answerability from the representation of CLS and START
582
+ cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
583
+ loss_fct_cls = nn.BCEWithLogitsLoss()
584
+ cls_loss = loss_fct_cls(cls_logits, is_impossible)
585
+
586
+ # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
587
+ total_loss += cls_loss * 0.5
588
+
589
+ outputs = (total_loss,) + outputs
590
+
591
+ else:
592
+ # during inference, compute the end logits based on beam search
593
+ bsz, slen, hsz = hidden_states.size()
594
+ start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
595
+
596
+ start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
597
+ start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
598
+ start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
599
+ start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
600
+
601
+ hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
602
+ p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
603
+ end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
604
+ end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
605
+
606
+ end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1) # shape (bsz, end_n_top, start_n_top)
607
+ end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
608
+ end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
609
+
610
+ start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
611
+ cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
612
+
613
+ outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
614
+
615
+ # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
616
+ # or (if labels are provided) (total_loss,)
617
+ return outputs
618
+
619
+
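# Editorial usage sketch (not part of the committed file). The config namespace below is an
# assumption; any object exposing these attributes would work with SQuADHead.
def _squad_head_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(hidden_size=768, layer_norm_eps=1e-12, start_n_top=5, end_n_top=5)
    head = SQuADHead(cfg)
    hidden = torch.randn(2, 384, 768)                    # (batch, seq_len, hidden)

    # Training: gold start/end positions are given -> a single-element tuple (total_loss,)
    (loss,) = head(hidden,
                   start_positions=torch.tensor([10, 42]),
                   end_positions=torch.tensor([12, 45]))

    # Inference: no labels -> beam-search style top-k start/end candidates plus cls_logits
    start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = head(hidden)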
620
+ class SequenceSummary(nn.Module):
621
+ r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
622
+ Args of the config class:
623
+ summary_type:
624
+ - 'last' => [default] take the last token hidden state (like XLNet)
625
+ - 'first' => take the first token hidden state (like Bert)
626
+ - 'mean' => take the mean of all tokens hidden states
627
+ - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
628
+ - 'attn' => Not implemented now, use multi-head attention
629
+ summary_use_proj: Add a projection after the vector extraction
630
+ summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
631
+ summary_activation: 'tanh' => add a tanh activation to the output, anything else => no activation (default).
632
+ summary_first_dropout: Add a dropout before the projection and activation
633
+ summary_last_dropout: Add a dropout after the projection and activation
634
+ """
635
+ def __init__(self, config):
636
+ super(SequenceSummary, self).__init__()
637
+
638
+ self.summary_type = config.summary_type if hasattr(config, 'summary_type') else 'last'
639
+ if self.summary_type == 'attn':
640
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
641
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
642
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
643
+ raise NotImplementedError
644
+
645
+ self.summary = Identity()
646
+ if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
647
+ if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
648
+ num_classes = config.num_labels
649
+ else:
650
+ num_classes = config.hidden_size
651
+ self.summary = nn.Linear(config.hidden_size, num_classes)
652
+
653
+ self.activation = Identity()
654
+ if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
655
+ self.activation = nn.Tanh()
656
+
657
+ self.first_dropout = Identity()
658
+ if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
659
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
660
+
661
+ self.last_dropout = Identity()
662
+ if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
663
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
664
+
665
+ def forward(self, hidden_states, cls_index=None):
666
+ """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
667
+ cls_index: [optional] position of the classification token if summary_type == 'cls_index',
668
+ shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
669
+ if summary_type == 'cls_index' and cls_index is None:
670
+ we take the last token of the sequence as classification token
671
+ """
672
+ if self.summary_type == 'last':
673
+ output = hidden_states[:, -1]
674
+ elif self.summary_type == 'first':
675
+ output = hidden_states[:, 0]
676
+ elif self.summary_type == 'mean':
677
+ output = hidden_states.mean(dim=1)
678
+ elif self.summary_type == 'cls_index':
679
+ if cls_index is None:
680
+ cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2]-1, dtype=torch.long)
681
+ else:
682
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
683
+ cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
684
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
685
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
686
+ elif self.summary_type == 'attn':
687
+ raise NotImplementedError
688
+
689
+ output = self.first_dropout(output)
690
+ output = self.summary(output)
691
+ output = self.activation(output)
692
+ output = self.last_dropout(output)
693
+
694
+ return output
695
+
696
+
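# Editorial usage sketch (not part of the committed file): pooling a 'first'-token summary.
# The config attributes are assumptions chosen only to exercise this path.
def _sequence_summary_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(summary_type='first', summary_use_proj=True,
                          summary_proj_to_labels=False, num_labels=0, hidden_size=768,
                          summary_activation='tanh',
                          summary_first_dropout=0.0, summary_last_dropout=0.1)
    pooler = SequenceSummary(cfg)
    hidden = torch.randn(4, 128, 768)        # (batch, seq_len, hidden)
    summary = pooler(hidden)                 # -> (4, 768), tanh-projected first-token state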
697
+ def prune_linear_layer(layer, index, dim=0):
698
+ """ Prune a linear layer (a model parameters) to keep only entries in index.
699
+ Return the pruned layer as a new layer with requires_grad=True.
700
+ Used to remove heads.
701
+ """
702
+ index = index.to(layer.weight.device)
703
+ W = layer.weight.index_select(dim, index).clone().detach()
704
+ if layer.bias is not None:
705
+ if dim == 1:
706
+ b = layer.bias.clone().detach()
707
+ else:
708
+ b = layer.bias[index].clone().detach()
709
+ new_size = list(layer.weight.size())
710
+ new_size[dim] = len(index)
711
+ new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
712
+ new_layer.weight.requires_grad = False
713
+ new_layer.weight.copy_(W.contiguous())
714
+ new_layer.weight.requires_grad = True
715
+ if layer.bias is not None:
716
+ new_layer.bias.requires_grad = False
717
+ new_layer.bias.copy_(b.contiguous())
718
+ new_layer.bias.requires_grad = True
719
+ return new_layer
720
+
721
+
722
+ def prune_conv1d_layer(layer, index, dim=1):
723
+ """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
724
+ A Conv1D works like a Linear layer (used e.g. in GPT-2) but with transposed weights.
725
+ Return the pruned layer as a new layer with requires_grad=True.
726
+ Used to remove heads.
727
+ """
728
+ index = index.to(layer.weight.device)
729
+ W = layer.weight.index_select(dim, index).clone().detach()
730
+ if dim == 0:
731
+ b = layer.bias.clone().detach()
732
+ else:
733
+ b = layer.bias[index].clone().detach()
734
+ new_size = list(layer.weight.size())
735
+ new_size[dim] = len(index)
736
+ new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
737
+ new_layer.weight.requires_grad = False
738
+ new_layer.weight.copy_(W.contiguous())
739
+ new_layer.weight.requires_grad = True
740
+ new_layer.bias.requires_grad = False
741
+ new_layer.bias.copy_(b.contiguous())
742
+ new_layer.bias.requires_grad = True
743
+ return new_layer
744
+
745
+
746
+ def prune_layer(layer, index, dim=None):
747
+ """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
748
+ Return the pruned layer as a new layer with requires_grad=True.
749
+ Used to remove heads.
750
+ """
751
+ if isinstance(layer, nn.Linear):
752
+ return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
753
+ elif isinstance(layer, Conv1D):
754
+ return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
755
+ else:
756
+ raise ValueError("Can't prune layer of class {}".format(layer.__class__))
model/tokenization_albert.py ADDED
@@ -0,0 +1,358 @@
1
+ """Tokenization classes."""
2
+
3
+ from __future__ import (absolute_import, division, print_function,
4
+ unicode_literals)
5
+ import collections
6
+ import unicodedata
7
+ import six
8
+ import logging
9
+ import sentencepiece as spm
10
+
11
+ logger = logging.getLogger(__name__)
12
+ SPIECE_UNDERLINE = u"▁"
13
+
14
+ def preprocess_text(inputs,remove_space=True,do_lower_case=True):
15
+ if remove_space:
16
+ outputs = ' '.join(inputs.strip().split())
17
+ else:
18
+ outputs = inputs
19
+ outputs = outputs.replace("``", '"').replace("''", '"')
20
+ if six.PY2 and isinstance(outputs, str):
21
+ outputs = outputs.decode('utf-8')
22
+ outputs = unicodedata.normalize("NFKD", outputs)
23
+ outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
24
+ if do_lower_case:
25
+ outputs = outputs.lower()
26
+ return outputs
27
+
28
+ def encode_pieces(sp_model, text, return_unicode=True, sample=False):
29
+ """turn sentences into word pieces."""
30
+ # text = preprocess_text(text,)
31
+ if six.PY2 and isinstance(text, unicode):
32
+ text = text.encode('utf-8')
33
+ if not sample:
34
+ pieces = sp_model.EncodeAsPieces(text)
35
+ else:
36
+ pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
37
+ new_pieces = []
38
+ for piece in pieces:
39
+ if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
40
+ cur_pieces = sp_model.EncodeAsPieces(
41
+ piece[:-1].replace(SPIECE_UNDERLINE, ''))
42
+ if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
43
+ if len(cur_pieces[0]) == 1:
44
+ cur_pieces = cur_pieces[1:]
45
+ else:
46
+ cur_pieces[0] = cur_pieces[0][1:]
47
+ cur_pieces.append(piece[-1])
48
+ new_pieces.extend(cur_pieces)
49
+ else:
50
+ new_pieces.append(piece)
51
+
52
+ # note(zhiliny): convert back to unicode for py2
53
+ if six.PY2 and return_unicode:
54
+ ret_pieces = []
55
+ for piece in new_pieces:
56
+ if isinstance(piece, str):
57
+ piece = piece.decode("utf-8")
58
+ ret_pieces.append(piece)
59
+ new_pieces = ret_pieces
60
+
61
+ return new_pieces
62
+
63
+ def encode_ids(sp_model, text, sample=False):
64
+ pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
65
+ ids = [sp_model.PieceToId(piece) for piece in pieces]
66
+ return ids
67
+
68
+
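# Editorial sketch (not part of the committed file): encoding text with a SentencePiece model.
# "spm.model" is a hypothetical path; any trained SentencePiece model file would do.
def _sentencepiece_example():
    sp = spm.SentencePieceProcessor()
    sp.Load("spm.model")                                     # hypothetical model path
    text = preprocess_text("Hello,  World!", do_lower_case=True)
    pieces = encode_pieces(sp, text, return_unicode=False)   # sub-word pieces
    ids = encode_ids(sp, text)                               # matching piece ids
    return pieces, ids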
69
+ def load_vocab(vocab_file):
70
+ """Loads a vocabulary file into a dictionary."""
71
+ vocab = collections.OrderedDict()
72
+ with open(vocab_file, "r", encoding="utf-8") as reader:
73
+ tokens = reader.readlines()
74
+ for index, token in enumerate(tokens):
75
+ token = token.rstrip('\n')
76
+ vocab[token] = index
77
+ return vocab
78
+
79
+ def convert_by_vocab(vocab, items):
80
+ """Converts a sequence of [tokens|ids] using the vocab."""
81
+ output = []
82
+ for item in items:
83
+ try:
84
+ output.append(vocab[item])
85
+ except KeyError:
86
+ output.append(vocab['[UNK]'])
87
+ return output
88
+
89
+ def convert_tokens_to_ids(vocab, tokens):
90
+ return convert_by_vocab(vocab, tokens)
91
+
92
+ def convert_ids_to_tokens(inv_vocab, ids):
93
+ return convert_by_vocab(inv_vocab, ids)
94
+
95
+ def whitespace_tokenize(text):
96
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
97
+ text = text.strip()
98
+ if not text:
99
+ return []
100
+ tokens = text.split()
101
+ return tokens
102
+
103
+ class FullTokenizer(object):
104
+ """Runs end-to-end tokenziation."""
105
+
106
+ def __init__(self, vocab_file, do_lower_case=True, spm_model_file=None):
107
+ self.vocab = None
108
+ self.sp_model = None
109
+ if spm_model_file:
110
+ self.sp_model = spm.SentencePieceProcessor()
111
+ logger.info("loading sentence piece model")
112
+ self.sp_model.Load(spm_model_file)
113
+
114
+ # # Note(mingdachen): For the purpose of consistent API, we are
115
+ # # generating a vocabulary for the sentence piece tokenizer.
116
+ # self.vocab = {self.sp_model.IdToPiece(i): i for i
117
+ # in range(self.sp_model.GetPieceSize())}
118
+ self.vocab = load_vocab(vocab_file)
119
+ else:
120
+ print("load vocab")
121
+ self.vocab = load_vocab(vocab_file)
122
+ print("load token")
123
+ self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
124
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab,unk_token="[UNK]", max_input_chars_per_word=100)
125
+ self.inv_vocab = {v: k for k, v in self.vocab.items()}
126
+
127
+ def tokenize(self, text):
128
+ if self.sp_model:
129
+ split_tokens = encode_pieces(self.sp_model, text, return_unicode=False)
130
+ else:
131
+ split_tokens = []
132
+ for token in self.basic_tokenizer.tokenize(text):
133
+ for sub_token in self.wordpiece_tokenizer.tokenize(token):
134
+ split_tokens.append(sub_token)
135
+
136
+ return split_tokens
137
+
138
+ def convert_tokens_to_ids(self, tokens):
139
+ if self.sp_model:
140
+ # return [self.sp_model.PieceToId(token) for token in tokens]
141
+ return convert_by_vocab(self.vocab, tokens)
142
+ else:
143
+ return convert_by_vocab(self.vocab, tokens)
144
+
145
+ def convert_ids_to_tokens(self, ids):
146
+ if self.sp_model:
147
+ logger.info("using sentence piece tokenzier.")
148
+ return [self.sp_model.IdToPiece(id_) for id_ in ids]
149
+ else:
150
+ return convert_by_vocab(self.inv_vocab, ids)
151
+
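# Editorial usage sketch (not part of the committed file): the vocab-file path (no SentencePiece
# model). "vocab.txt" is a placeholder path and the sample string is arbitrary.
def _full_tokenizer_example():
    tokenizer = FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)
    tokens = tokenizer.tokenize("some input text")
    ids = tokenizer.convert_tokens_to_ids(tokens)
    return tokenizer.convert_ids_to_tokens(ids)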
152
+ class BasicTokenizer(object):
153
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
154
+
155
+ def __init__(self, do_lower_case=True):
156
+ """Constructs a BasicTokenizer.
157
+
158
+ Args:
159
+ do_lower_case: Whether to lower case the input.
160
+ """
161
+ self.do_lower_case = do_lower_case
162
+
163
+ def tokenize(self, text):
164
+ """Tokenizes a piece of text."""
165
+ text = self._clean_text(text)
166
+
167
+ # This was added on November 1st, 2018 for the multilingual and Chinese
168
+ # models. This is also applied to the English models now, but it doesn't
169
+ # matter since the English models were not trained on any Chinese data
170
+ # and generally don't have any Chinese data in them (there are Chinese
171
+ # characters in the vocabulary because Wikipedia does have some Chinese
172
+ # words in the English Wikipedia.).
173
+ text = self._tokenize_chinese_chars(text)
174
+ orig_tokens = whitespace_tokenize(text)
175
+ split_tokens = []
176
+ for token in orig_tokens:
177
+ if self.do_lower_case:
178
+ token = token.lower()
179
+ token = self._run_strip_accents(token)
180
+ split_tokens.extend(self._run_split_on_punc(token))
181
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
182
+ return output_tokens
183
+
184
+ def _run_strip_accents(self, text):
185
+ """Strips accents from a piece of text."""
186
+ text = unicodedata.normalize("NFD", text)
187
+ output = []
188
+ for char in text:
189
+ cat = unicodedata.category(char)
190
+ if cat == "Mn":
191
+ continue
192
+ output.append(char)
193
+ return "".join(output)
194
+
195
+ def _run_split_on_punc(self, text):
196
+ """Splits punctuation on a piece of text."""
197
+ chars = list(text)
198
+ i = 0
199
+ start_new_word = True
200
+ output = []
201
+ while i < len(chars):
202
+ char = chars[i]
203
+ if _is_punctuation(char):
204
+ output.append([char])
205
+ start_new_word = True
206
+ else:
207
+ if start_new_word:
208
+ output.append([])
209
+ start_new_word = False
210
+ output[-1].append(char)
211
+ i += 1
212
+
213
+ return ["".join(x) for x in output]
214
+
215
+ def _tokenize_chinese_chars(self, text):
216
+ """Adds whitespace around any CJK character."""
217
+ output = []
218
+ for char in text:
219
+ cp = ord(char)
220
+ if self._is_chinese_char(cp):
221
+ output.append(" ")
222
+ output.append(char)
223
+ output.append(" ")
224
+ else:
225
+ output.append(char)
226
+ return "".join(output)
227
+
228
+ def _is_chinese_char(self, cp):
229
+ """Checks whether CP is the codepoint of a CJK character."""
230
+ # This defines a "chinese character" as anything in the CJK Unicode block:
231
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
232
+ #
233
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
234
+ # despite its name. The modern Korean Hangul alphabet is a different block,
235
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
236
+ # space-separated words, so they are not treated specially and handled
237
+ # like all of the other languages.
238
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
239
+ (cp >= 0x3400 and cp <= 0x4DBF) or #
240
+ (cp >= 0x20000 and cp <= 0x2A6DF) or #
241
+ (cp >= 0x2A700 and cp <= 0x2B73F) or #
242
+ (cp >= 0x2B740 and cp <= 0x2B81F) or #
243
+ (cp >= 0x2B820 and cp <= 0x2CEAF) or
244
+ (cp >= 0xF900 and cp <= 0xFAFF) or #
245
+ (cp >= 0x2F800 and cp <= 0x2FA1F)): #
246
+ return True
247
+
248
+ return False
249
+
250
+ def _clean_text(self, text):
251
+ """Performs invalid character removal and whitespace cleanup on text."""
252
+ output = []
253
+ for char in text:
254
+ cp = ord(char)
255
+ if cp == 0 or cp == 0xfffd or _is_control(char):
256
+ continue
257
+ if _is_whitespace(char):
258
+ output.append(" ")
259
+ else:
260
+ output.append(char)
261
+ return "".join(output)
262
+
263
+ class WordpieceTokenizer(object):
264
+ """Runs WordPiece tokenization."""
265
+
266
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
267
+ self.vocab = vocab
268
+ self.unk_token = unk_token
269
+ self.max_input_chars_per_word = max_input_chars_per_word
270
+
271
+ def tokenize(self, text):
272
+ """Tokenizes a piece of text into its word pieces.
273
+
274
+ This uses a greedy longest-match-first algorithm to perform tokenization
275
+ using the given vocabulary.
276
+
277
+ For example:
278
+ input = "unaffable"
279
+ output = ["un", "##aff", "##able"]
280
+
281
+ Args:
282
+ text: A single token or whitespace separated tokens. This should have
283
+ already been passed through `BasicTokenizer`.
284
+
285
+ Returns:
286
+ A list of wordpiece tokens.
287
+ """
288
+
289
+ output_tokens = []
290
+ for token in whitespace_tokenize(text):
291
+ chars = list(token)
292
+ if len(chars) > self.max_input_chars_per_word:
293
+ output_tokens.append(self.unk_token)
294
+ continue
295
+
296
+ is_bad = False
297
+ start = 0
298
+ sub_tokens = []
299
+ while start < len(chars):
300
+ end = len(chars)
301
+ cur_substr = None
302
+ while start < end:
303
+ substr = "".join(chars[start:end])
304
+ if start > 0:
305
+ substr = "##" + substr
306
+ if substr in self.vocab:
307
+ cur_substr = substr
308
+ break
309
+ end -= 1
310
+ if cur_substr is None:
311
+ is_bad = True
312
+ break
313
+ sub_tokens.append(cur_substr)
314
+ start = end
315
+
316
+ if is_bad:
317
+ output_tokens.append(self.unk_token)
318
+ else:
319
+ output_tokens.extend(sub_tokens)
320
+ return output_tokens
321
+
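# Editorial sketch (not part of the committed file): the greedy longest-match-first walk above,
# traced with a toy vocabulary (an assumption, not the shipped vocab file).
def _wordpiece_example():
    toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
    wp = WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]")
    assert wp.tokenize("unaffable") == ["un", "##aff", "##able"]
    assert wp.tokenize("xyz") == ["[UNK]"]        # no prefix of "xyz" is in the vocab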
322
+ def _is_whitespace(char):
323
+ """Checks whether `chars` is a whitespace character."""
324
+ # \t, \n, and \r are technically control characters but we treat them
325
+ # as whitespace since they are generally considered as such.
326
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
327
+ return True
328
+ cat = unicodedata.category(char)
329
+ if cat == "Zs":
330
+ return True
331
+ return False
332
+
333
+
334
+ def _is_control(char):
335
+ """Checks whether `chars` is a control character."""
336
+ # These are technically control characters but we count them as whitespace
337
+ # characters.
338
+ if char == "\t" or char == "\n" or char == "\r":
339
+ return False
340
+ cat = unicodedata.category(char)
341
+ if cat in ("Cc", "Cf"):
342
+ return True
343
+ return False
344
+
345
+ def _is_punctuation(char):
346
+ """Checks whether `chars` is a punctuation character."""
347
+ cp = ord(char)
348
+ # We treat all non-letter/number ASCII as punctuation.
349
+ # Characters such as "^", "$", and "`" are not in the Unicode
350
+ # Punctuation class but we treat them as punctuation anyways, for
351
+ # consistency.
352
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
353
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
354
+ return True
355
+ cat = unicodedata.category(char)
356
+ if cat.startswith("P"):
357
+ return True
358
+ return False
model/tokenization_bert.py ADDED
@@ -0,0 +1,441 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes."""
16
+
17
+ from __future__ import absolute_import, division, print_function, unicode_literals
18
+
19
+ import collections
20
+ import logging
21
+ import os
22
+ import unicodedata
23
+ from io import open
24
+
25
+ from .tokenization_utils import PreTrainedTokenizer
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
30
+
31
+ def load_vocab(vocab_file):
32
+ """Loads a vocabulary file into a dictionary."""
33
+ vocab = collections.OrderedDict()
34
+ with open(vocab_file, "r", encoding="utf-8") as reader:
35
+ tokens = reader.readlines()
36
+ for index, token in enumerate(tokens):
37
+ token = token.rstrip('\n')
38
+ vocab[token] = index
39
+ return vocab
40
+
41
+
42
+ def whitespace_tokenize(text):
43
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
44
+ text = text.strip()
45
+ if not text:
46
+ return []
47
+ tokens = text.split()
48
+ return tokens
49
+
50
+
51
+ class BertTokenizer(PreTrainedTokenizer):
52
+ r"""
53
+ Constructs a BertTokenizer.
54
+ :class:`~transformers.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
55
+
56
+ Args:
57
+ vocab_file: Path to a one-wordpiece-per-line vocabulary file
58
+ do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
59
+ do_basic_tokenize: Whether to do basic tokenization before wordpiece.
60
+ max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the
61
+ minimum of this value (if specified) and the underlying BERT model's sequence length.
62
+ never_split: List of tokens which will never be split during tokenization. Only has an effect when
63
+ do_wordpiece_only=False
64
+ """
65
+
66
+ vocab_files_names = VOCAB_FILES_NAMES
67
+
68
+ def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
69
+ unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
70
+ mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs):
71
+ """Constructs a BertTokenizer.
72
+
73
+ Args:
74
+ **vocab_file**: Path to a one-wordpiece-per-line vocabulary file
75
+ **do_lower_case**: (`optional`) boolean (default True)
76
+ Whether to lower case the input
77
+ Only has an effect when do_basic_tokenize=True
78
+ **do_basic_tokenize**: (`optional`) boolean (default True)
79
+ Whether to do basic tokenization before wordpiece.
80
+ **never_split**: (`optional`) list of string
81
+ List of tokens which will never be split during tokenization.
82
+ Only has an effect when do_basic_tokenize=True
83
+ **tokenize_chinese_chars**: (`optional`) boolean (default True)
84
+ Whether to tokenize Chinese characters.
85
+ This should likely be deactivated for Japanese:
86
+ see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
87
+ """
88
+ super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
89
+ pad_token=pad_token, cls_token=cls_token,
90
+ mask_token=mask_token, **kwargs)
91
+ self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
92
+ self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
93
+
94
+ if not os.path.isfile(vocab_file):
95
+ raise ValueError(
96
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
97
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
98
+ self.vocab = load_vocab(vocab_file)
99
+ self.ids_to_tokens = collections.OrderedDict(
100
+ [(ids, tok) for tok, ids in self.vocab.items()])
101
+ self.do_basic_tokenize = do_basic_tokenize
102
+ if do_basic_tokenize:
103
+ self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
104
+ never_split=never_split,
105
+ tokenize_chinese_chars=tokenize_chinese_chars)
106
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
107
+
108
+ @property
109
+ def vocab_size(self):
110
+ return len(self.vocab)
111
+
112
+ def _tokenize(self, text):
113
+ split_tokens = []
114
+ if self.do_basic_tokenize:
115
+ for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
116
+ for sub_token in self.wordpiece_tokenizer.tokenize(token):
117
+ split_tokens.append(sub_token)
118
+ else:
119
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
120
+ return split_tokens
121
+
122
+ def _convert_token_to_id(self, token):
123
+ """ Converts a token (str/unicode) in an id using the vocab. """
124
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
125
+
126
+ def _convert_id_to_token(self, index):
127
+ """Converts an index (integer) in a token (string/unicode) using the vocab."""
128
+ return self.ids_to_tokens.get(index, self.unk_token)
129
+
130
+ def convert_tokens_to_string(self, tokens):
131
+ """ Converts a sequence of tokens (string) in a single string. """
132
+ out_string = ' '.join(tokens).replace(' ##', '').strip()
133
+ return out_string
134
+
135
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
136
+ """
137
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
138
+ by concatenating and adding special tokens.
139
+ A BERT sequence has the following format:
140
+ single sequence: [CLS] X [SEP]
141
+ pair of sequences: [CLS] A [SEP] B [SEP]
142
+ """
143
+ if token_ids_1 is None:
144
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
145
+ cls = [self.cls_token_id]
146
+ sep = [self.sep_token_id]
147
+ return cls + token_ids_0 + sep + token_ids_1 + sep
148
+
149
+ def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
150
+ """
151
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
152
+ special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
153
+
154
+ Args:
155
+ token_ids_0: list of ids (must not contain special tokens)
156
+ token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
157
+ for sequence pairs
158
+ already_has_special_tokens: (default False) Set to True if the token list is already formatted with
159
+ special tokens for the model
160
+
161
+ Returns:
162
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
163
+ """
164
+
165
+ if already_has_special_tokens:
166
+ if token_ids_1 is not None:
167
+ raise ValueError("You should not supply a second sequence if the provided sequence of "
168
+ "ids is already formated with special tokens for the model.")
169
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
170
+
171
+ if token_ids_1 is not None:
172
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
173
+ return [1] + ([0] * len(token_ids_0)) + [1]
174
+
175
+ def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
176
+ """
177
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
178
+ A BERT sequence pair mask has the following format:
179
+ 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
180
+ | first sequence | second sequence
181
+
182
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
183
+ """
184
+ sep = [self.sep_token_id]
185
+ cls = [self.cls_token_id]
186
+ if token_ids_1 is None:
187
+ return len(cls + token_ids_0 + sep) * [0]
188
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
189
+
190
+ def save_vocabulary(self, vocab_path):
191
+ """Save the tokenizer vocabulary to a directory or file."""
192
+ index = 0
193
+ if os.path.isdir(vocab_path):
194
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
195
+ else:
196
+ vocab_file = vocab_path
197
+ with open(vocab_file, "w", encoding="utf-8") as writer:
198
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
199
+ if index != token_index:
200
+ logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
201
+ " Please check that the vocabulary is not corrupted!".format(vocab_file))
202
+ index = token_index
203
+ writer.write(token + u'\n')
204
+ index += 1
205
+ return (vocab_file,)
206
+
207
+
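# Editorial usage sketch (not part of the committed file): how the special-token helpers lay
# out a sentence pair. "vocab.txt" is a placeholder path; the id values depend on the vocab.
def _bert_tokenizer_pair_example():
    tok = BertTokenizer(vocab_file="vocab.txt")
    a = tok.convert_tokens_to_ids(tok.tokenize("first sequence"))
    b = tok.convert_tokens_to_ids(tok.tokenize("second sequence"))

    inputs = tok.build_inputs_with_special_tokens(a, b)        # [CLS] a [SEP] b [SEP]
    segments = tok.create_token_type_ids_from_sequences(a, b)  # 0s over a, 1s over b
    special = tok.get_special_tokens_mask(a, b)                # 1 at CLS/SEP positions
    assert len(inputs) == len(segments) == len(special)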
208
+ class BasicTokenizer(object):
209
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
210
+
211
+ def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
212
+ """ Constructs a BasicTokenizer.
213
+
214
+ Args:
215
+ **do_lower_case**: Whether to lower case the input.
216
+ **never_split**: (`optional`) list of str
217
+ Kept for backward compatibility purposes.
218
+ Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
219
+ List of token not to split.
220
+ **tokenize_chinese_chars**: (`optional`) boolean (default True)
221
+ Whether to tokenize Chinese characters.
222
+ This should likely be deactivated for Japanese:
223
+ see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
224
+ """
225
+ if never_split is None:
226
+ never_split = []
227
+ self.do_lower_case = do_lower_case
228
+ self.never_split = never_split
229
+ self.tokenize_chinese_chars = tokenize_chinese_chars
230
+
231
+ def tokenize(self, text, never_split=None):
232
+ """ Basic Tokenization of a piece of text.
233
+ Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.
234
+
235
+ Args:
236
+ **never_split**: (`optional`) list of str
237
+ Kept for backward compatibility purposes.
238
+ Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
239
+ List of token not to split.
240
+ """
241
+ never_split = self.never_split + (never_split if never_split is not None else [])
242
+ text = self._clean_text(text)
243
+ # This was added on November 1st, 2018 for the multilingual and Chinese
244
+ # models. This is also applied to the English models now, but it doesn't
245
+ # matter since the English models were not trained on any Chinese data
246
+ # and generally don't have any Chinese data in them (there are Chinese
247
+ # characters in the vocabulary because Wikipedia does have some Chinese
248
+ # words in the English Wikipedia.).
249
+ if self.tokenize_chinese_chars:
250
+ text = self._tokenize_chinese_chars(text)
251
+ orig_tokens = whitespace_tokenize(text)
252
+ split_tokens = []
253
+ for token in orig_tokens:
254
+ if self.do_lower_case and token not in never_split:
255
+ token = token.lower()
256
+ token = self._run_strip_accents(token)
257
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
258
+
259
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
260
+ return output_tokens
261
+
262
+ def _run_strip_accents(self, text):
263
+ """Strips accents from a piece of text."""
264
+ text = unicodedata.normalize("NFD", text)
265
+ output = []
266
+ for char in text:
267
+ cat = unicodedata.category(char)
268
+ if cat == "Mn":
269
+ continue
270
+ output.append(char)
271
+ return "".join(output)
272
+
273
+ def _run_split_on_punc(self, text, never_split=None):
274
+ """Splits punctuation on a piece of text."""
275
+ if never_split is not None and text in never_split:
276
+ return [text]
277
+ chars = list(text)
278
+ i = 0
279
+ start_new_word = True
280
+ output = []
281
+ while i < len(chars):
282
+ char = chars[i]
283
+ if _is_punctuation(char):
284
+ output.append([char])
285
+ start_new_word = True
286
+ else:
287
+ if start_new_word:
288
+ output.append([])
289
+ start_new_word = False
290
+ output[-1].append(char)
291
+ i += 1
292
+
293
+ return ["".join(x) for x in output]
294
+
295
+ def _tokenize_chinese_chars(self, text):
296
+ """Adds whitespace around any CJK character."""
297
+ output = []
298
+ for char in text:
299
+ cp = ord(char)
300
+ if self._is_chinese_char(cp):
301
+ output.append(" ")
302
+ output.append(char)
303
+ output.append(" ")
304
+ else:
305
+ output.append(char)
306
+ return "".join(output)
307
+
308
+ def _is_chinese_char(self, cp):
309
+ """Checks whether CP is the codepoint of a CJK character."""
310
+ # This defines a "chinese character" as anything in the CJK Unicode block:
311
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
312
+ #
313
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
314
+ # despite its name. The modern Korean Hangul alphabet is a different block,
315
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
316
+ # space-separated words, so they are not treated specially and handled
317
+ # like all of the other languages.
318
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
319
+ (cp >= 0x3400 and cp <= 0x4DBF) or #
320
+ (cp >= 0x20000 and cp <= 0x2A6DF) or #
321
+ (cp >= 0x2A700 and cp <= 0x2B73F) or #
322
+ (cp >= 0x2B740 and cp <= 0x2B81F) or #
323
+ (cp >= 0x2B820 and cp <= 0x2CEAF) or
324
+ (cp >= 0xF900 and cp <= 0xFAFF) or #
325
+ (cp >= 0x2F800 and cp <= 0x2FA1F)): #
326
+ return True
327
+
328
+ return False
329
+
330
+ def _clean_text(self, text):
331
+ """Performs invalid character removal and whitespace cleanup on text."""
332
+ output = []
333
+ for char in text:
334
+ cp = ord(char)
335
+ if cp == 0 or cp == 0xfffd or _is_control(char):
336
+ continue
337
+ if _is_whitespace(char):
338
+ output.append(" ")
339
+ else:
340
+ output.append(char)
341
+ return "".join(output)
342
+
343
+
344
+ class WordpieceTokenizer(object):
345
+ """Runs WordPiece tokenization."""
346
+
347
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
348
+ self.vocab = vocab
349
+ self.unk_token = unk_token
350
+ self.max_input_chars_per_word = max_input_chars_per_word
351
+
352
+ def tokenize(self, text):
353
+ """Tokenizes a piece of text into its word pieces.
354
+
355
+ This uses a greedy longest-match-first algorithm to perform tokenization
356
+ using the given vocabulary.
357
+
358
+ For example:
359
+ input = "unaffable"
360
+ output = ["un", "##aff", "##able"]
361
+
362
+ Args:
363
+ text: A single token or whitespace separated tokens. This should have
364
+ already been passed through `BasicTokenizer`.
365
+
366
+ Returns:
367
+ A list of wordpiece tokens.
368
+ """
369
+
370
+ output_tokens = []
371
+ for token in whitespace_tokenize(text):
372
+ chars = list(token)
373
+ if len(chars) > self.max_input_chars_per_word:
374
+ output_tokens.append(self.unk_token)
375
+ continue
376
+
377
+ is_bad = False
378
+ start = 0
379
+ sub_tokens = []
380
+ while start < len(chars):
381
+ end = len(chars)
382
+ cur_substr = None
383
+ while start < end:
384
+ substr = "".join(chars[start:end])
385
+ if start > 0:
386
+ substr = "##" + substr
387
+ if substr in self.vocab:
388
+ cur_substr = substr
389
+ break
390
+ end -= 1
391
+ if cur_substr is None:
392
+ is_bad = True
393
+ break
394
+ sub_tokens.append(cur_substr)
395
+ start = end
396
+
397
+ if is_bad:
398
+ output_tokens.append(self.unk_token)
399
+ else:
400
+ output_tokens.extend(sub_tokens)
401
+ return output_tokens
402
+
403
+
404
+ def _is_whitespace(char):
405
+ """Checks whether `chars` is a whitespace character."""
406
+ # \t, \n, and \r are technically control characters but we treat them
407
+ # as whitespace since they are generally considered as such.
408
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
409
+ return True
410
+ cat = unicodedata.category(char)
411
+ if cat == "Zs":
412
+ return True
413
+ return False
414
+
415
+
416
+ def _is_control(char):
417
+ """Checks whether `chars` is a control character."""
418
+ # These are technically control characters but we count them as whitespace
419
+ # characters.
420
+ if char == "\t" or char == "\n" or char == "\r":
421
+ return False
422
+ cat = unicodedata.category(char)
423
+ if cat.startswith("C"):
424
+ return True
425
+ return False
426
+
427
+
428
+ def _is_punctuation(char):
429
+ """Checks whether `chars` is a punctuation character."""
430
+ cp = ord(char)
431
+ # We treat all non-letter/number ASCII as punctuation.
432
+ # Characters such as "^", "$", and "`" are not in the Unicode
433
+ # Punctuation class but we treat them as punctuation anyways, for
434
+ # consistency.
435
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
436
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
437
+ return True
438
+ cat = unicodedata.category(char)
439
+ if cat.startswith("P"):
440
+ return True
441
+ return False
model/tokenization_utils.py ADDED
@@ -0,0 +1,1065 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for OpenAI GPT."""
16
+ from __future__ import (absolute_import, division, print_function,
17
+ unicode_literals)
18
+
19
+ import logging
20
+ import os
21
+ import json
22
+ import six
23
+ import copy
24
+ from io import open
25
+
26
+ from .file_utils import cached_path
27
+
28
+ import torch
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+ SPECIAL_TOKENS_MAP_FILE = 'special_tokens_map.json'
33
+ ADDED_TOKENS_FILE = 'added_tokens.json'
34
+ TOKENIZER_CONFIG_FILE = 'tokenizer_config.json'
35
+
36
+ class PreTrainedTokenizer(object):
37
+ """ Base class for all tokenizers.
38
+ Handles all the shared methods for tokenization and special tokens, as well as methods for downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.
39
+
40
+ This class also contains the added tokens in a unified way on top of all tokenizers, so we don't have to handle the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
41
+
42
+ Class attributes (overridden by derived classes):
43
+
44
+ - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file required by the model, and as associated values, the filename for saving the associated file (string).
45
+ - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the associated pretrained vocabulary file.
46
+ - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, the maximum length of the sequence inputs of this model, or None if the model has no maximum input size.
47
+ - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained models, and as associated values, a dictionary of specific arguments to pass to the ``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the ``from_pretrained()`` method.
48
+
49
+ Parameters:
50
+
51
+ - ``bos_token``: (`Optional`) string: a beginning of sentence token. Will be associated to ``self.bos_token`` and ``self.bos_token_id``
52
+
53
+ - ``eos_token``: (`Optional`) string: an end of sentence token. Will be associated to ``self.eos_token`` and ``self.eos_token_id``
54
+
55
+ - ``unk_token``: (`Optional`) string: an unknown token. Will be associated to ``self.unk_token`` and ``self.unk_token_id``
56
+
57
+ - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence). Will be associated to ``self.sep_token`` and ``self.sep_token_id``
58
+
59
+ - ``pad_token``: (`Optional`) string: a padding token. Will be associated to ``self.pad_token`` and ``self.pad_token_id``
60
+
61
+ - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model). Will be associated to ``self.cls_token`` and ``self.cls_token_id``
62
+
63
+ - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
64
+
65
+ - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens. Adding all special tokens here ensure they won't be split by the tokenization process. Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
66
+ """
67
+ vocab_files_names = {}
68
+ pretrained_vocab_files_map = {}
69
+ pretrained_init_configuration = {}
70
+ max_model_input_sizes = {}
71
+
72
+ SPECIAL_TOKENS_ATTRIBUTES = ["bos_token", "eos_token", "unk_token", "sep_token",
73
+ "pad_token", "cls_token", "mask_token",
74
+ "additional_special_tokens"]
75
+
76
+ @property
77
+ def bos_token(self):
78
+ """ Beginning of sentence token (string). Log an error if used while not having been set. """
79
+ if self._bos_token is None:
80
+ logger.error("Using bos_token, but it is not set yet.")
81
+ return self._bos_token
82
+
83
+ @property
84
+ def eos_token(self):
85
+ """ End of sentence token (string). Log an error if used while not having been set. """
86
+ if self._eos_token is None:
87
+ logger.error("Using eos_token, but it is not set yet.")
88
+ return self._eos_token
89
+
90
+ @property
91
+ def unk_token(self):
92
+ """ Unknown token (string). Log an error if used while not having been set. """
93
+ if self._unk_token is None:
94
+ logger.error("Using unk_token, but it is not set yet.")
95
+ return self._unk_token
96
+
97
+ @property
98
+ def sep_token(self):
99
+ """ Separation token (string). E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
100
+ if self._sep_token is None:
101
+ logger.error("Using sep_token, but it is not set yet.")
102
+ return self._sep_token
103
+
104
+ @property
105
+ def pad_token(self):
106
+ """ Padding token (string). Log an error if used while not having been set. """
107
+ if self._pad_token is None:
108
+ logger.error("Using pad_token, but it is not set yet.")
109
+ return self._pad_token
110
+
111
+ @property
112
+ def cls_token(self):
113
+ """ Classification token (string). E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
114
+ if self._cls_token is None:
115
+ logger.error("Using cls_token, but it is not set yet.")
116
+ return self._cls_token
117
+
118
+ @property
119
+ def mask_token(self):
120
+ """ Mask token (string). E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
121
+ if self._mask_token is None:
122
+ logger.error("Using mask_token, but it is not set yet.")
123
+ return self._mask_token
124
+
125
+ @property
126
+ def additional_special_tokens(self):
127
+ """ All the additional special tokens you may want to use (list of strings). Log an error if used while not having been set. """
128
+ if self._additional_special_tokens is None:
129
+ logger.error("Using additional_special_tokens, but it is not set yet.")
130
+ return self._additional_special_tokens
131
+
132
+ @bos_token.setter
133
+ def bos_token(self, value):
134
+ self._bos_token = value
135
+
136
+ @eos_token.setter
137
+ def eos_token(self, value):
138
+ self._eos_token = value
139
+
140
+ @unk_token.setter
141
+ def unk_token(self, value):
142
+ self._unk_token = value
143
+
144
+ @sep_token.setter
145
+ def sep_token(self, value):
146
+ self._sep_token = value
147
+
148
+ @pad_token.setter
149
+ def pad_token(self, value):
150
+ self._pad_token = value
151
+
152
+ @cls_token.setter
153
+ def cls_token(self, value):
154
+ self._cls_token = value
155
+
156
+ @mask_token.setter
157
+ def mask_token(self, value):
158
+ self._mask_token = value
159
+
160
+ @additional_special_tokens.setter
161
+ def additional_special_tokens(self, value):
162
+ self._additional_special_tokens = value
163
+
164
+ @property
165
+ def bos_token_id(self):
166
+ """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
167
+ return self.convert_tokens_to_ids(self.bos_token)
168
+
169
+ @property
170
+ def eos_token_id(self):
171
+ """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
172
+ return self.convert_tokens_to_ids(self.eos_token)
173
+
174
+ @property
175
+ def unk_token_id(self):
176
+ """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
177
+ return self.convert_tokens_to_ids(self.unk_token)
178
+
179
+ @property
180
+ def sep_token_id(self):
181
+ """ Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
182
+ return self.convert_tokens_to_ids(self.sep_token)
183
+
184
+ @property
185
+ def pad_token_id(self):
186
+ """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """
187
+ return self.convert_tokens_to_ids(self.pad_token)
188
+
189
+ @property
190
+ def cls_token_id(self):
191
+ """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
192
+ return self.convert_tokens_to_ids(self.cls_token)
193
+
194
+ @property
195
+ def mask_token_id(self):
196
+ """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
197
+ return self.convert_tokens_to_ids(self.mask_token)
198
+
199
+ @property
200
+ def additional_special_tokens_ids(self):
201
+ """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
202
+ return self.convert_tokens_to_ids(self.additional_special_tokens)
203
+
204
+ def __init__(self, max_len=None, **kwargs):
205
+ self._bos_token = None
206
+ self._eos_token = None
207
+ self._unk_token = None
208
+ self._sep_token = None
209
+ self._pad_token = None
210
+ self._cls_token = None
211
+ self._mask_token = None
212
+ self._additional_special_tokens = []
213
+
214
+ self.max_len = max_len if max_len is not None else int(1e12)
215
+
216
+ # Added tokens
217
+ self.added_tokens_encoder = {}
218
+ self.added_tokens_decoder = {}
219
+
220
+ # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
221
+ self.init_inputs = ()
222
+ self.init_kwargs = {}
223
+
224
+ for key, value in kwargs.items():
225
+ if key in self.SPECIAL_TOKENS_ATTRIBUTES:
226
+ if key == 'additional_special_tokens':
227
+ assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
228
+ else:
229
+ assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
230
+ setattr(self, key, value)
231
+
232
+
233
+ @classmethod
234
+ def from_pretrained(cls, *inputs, **kwargs):
235
+ r"""
236
+ Instantiate a :class:`~transformers.PreTrainedTokenizer` (or a derived class) from a predefined tokenizer.
237
+
238
+ Args:
239
+ pretrained_model_name_or_path: either:
240
+
241
+ - a string with the `shortcut name` of a predefined tokenizer to load from cache or download, e.g.: ``bert-base-uncased``.
242
+ - a path to a `directory` containing vocabulary files required by the tokenizer, for instance saved using the :func:`~transformers.PreTrainedTokenizer.save_pretrained` method, e.g.: ``./my_model_directory/``.
243
+ - (not applicable to all derived classes) a path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (e.g. Bert, XLNet), e.g.: ``./my_model_directory/vocab.txt``.
244
+
245
+ cache_dir: (`optional`) string:
246
+ Path to a directory in which downloaded predefined tokenizer vocabulary files should be cached if the standard cache should not be used.
247
+
248
+ force_download: (`optional`) boolean, default False:
249
+ Force a (re-)download of the vocabulary files and override the cached versions if they exist.
250
+
251
+ proxies: (`optional`) dict, default None:
252
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
253
+ The proxies are used on each request.
254
+
255
+ inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.
256
+
257
+ kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details.
258
+
259
+ Examples::
260
+
261
+ # We can't instantiate the base class `PreTrainedTokenizer` directly, so let's show our examples on a derived class: BertTokenizer
262
+
263
+ # Download vocabulary from S3 and cache.
264
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
265
+
266
+ # If vocabulary files are in a directory (e.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`)
267
+ tokenizer = BertTokenizer.from_pretrained('./test/saved_model/')
268
+
269
+ # If the tokenizer uses a single vocabulary file, you can point directly to this file
270
+ tokenizer = BertTokenizer.from_pretrained('./test/saved_model/my_vocab.txt')
271
+
272
+ # You can link tokens to special vocabulary when instantiating
273
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', unk_token='<unk>')
274
+ # You should be sure '<unk>' is in the vocabulary when doing that.
275
+ # Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead
276
+ assert tokenizer.unk_token == '<unk>'
277
+
278
+ """
279
+ return cls._from_pretrained(*inputs, **kwargs)
280
+
281
+
282
+ @classmethod
283
+ def _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
284
+ cache_dir = kwargs.pop('cache_dir', None)
285
+ force_download = kwargs.pop('force_download', False)
286
+ proxies = kwargs.pop('proxies', None)
287
+
288
+ s3_models = list(cls.max_model_input_sizes.keys())
289
+ vocab_files = {}
290
+ init_configuration = {}
291
+ if pretrained_model_name_or_path in s3_models:
292
+ # Get the vocabulary from AWS S3 bucket
293
+ for file_id, map_list in cls.pretrained_vocab_files_map.items():
294
+ vocab_files[file_id] = map_list[pretrained_model_name_or_path]
295
+ if cls.pretrained_init_configuration and pretrained_model_name_or_path in cls.pretrained_init_configuration:
296
+ init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path]
297
+ else:
298
+ # Get the vocabulary from local files
299
+ logger.info(
300
+ "Model name '{}' not found in model shortcut name list ({}). "
301
+ "Assuming '{}' is a path or url to a directory containing tokenizer files.".format(
302
+ pretrained_model_name_or_path, ', '.join(s3_models),
303
+ pretrained_model_name_or_path))
304
+
305
+ # Look for the tokenizer main vocabulary files
306
+ for file_id, file_name in cls.vocab_files_names.items():
307
+ if os.path.isdir(pretrained_model_name_or_path):
308
+ # If a directory is provided we look for the standard filenames
309
+ full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
310
+ else:
311
+ # If a path to a file is provided we use it (will only work for non-BPE tokenizer using a single vocabulary file)
312
+ full_file_name = pretrained_model_name_or_path
313
+ if not os.path.exists(full_file_name):
314
+ logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
315
+ full_file_name = None
316
+ vocab_files[file_id] = full_file_name
317
+
318
+ # Look for the additional tokens files
319
+ additional_files_names = {'added_tokens_file': ADDED_TOKENS_FILE,
320
+ 'special_tokens_map_file': SPECIAL_TOKENS_MAP_FILE,
321
+ 'tokenizer_config_file': TOKENIZER_CONFIG_FILE,
322
+ }
323
+
324
+ # If a path to a file was provided, get the parent directory
325
+ saved_directory = pretrained_model_name_or_path
326
+ if os.path.exists(saved_directory) and not os.path.isdir(saved_directory):
327
+ saved_directory = os.path.dirname(saved_directory)
328
+
329
+ for file_id, file_name in additional_files_names.items():
330
+ full_file_name = os.path.join(saved_directory, file_name)
331
+ if not os.path.exists(full_file_name):
332
+ logger.info("Didn't find file {}. We won't load it.".format(full_file_name))
333
+ full_file_name = None
334
+ vocab_files[file_id] = full_file_name
335
+
336
+ if all(full_file_name is None for full_file_name in vocab_files.values()):
337
+ raise EnvironmentError(
338
+ "Model name '{}' was not found in tokenizers model name list ({}). "
339
+ "We assumed '{}' was a path or url to a directory containing vocabulary files "
340
+ "named {} but couldn't find such vocabulary files at this path or url.".format(
341
+ pretrained_model_name_or_path, ', '.join(s3_models),
342
+ pretrained_model_name_or_path,
343
+ list(cls.vocab_files_names.values())))
344
+
345
+ # Get files from url, cache, or disk depending on the case
346
+ try:
347
+ resolved_vocab_files = {}
348
+ for file_id, file_path in vocab_files.items():
349
+ if file_path is None:
350
+ resolved_vocab_files[file_id] = None
351
+ else:
352
+ resolved_vocab_files[file_id] = cached_path(file_path, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
353
+ except EnvironmentError:
354
+ if pretrained_model_name_or_path in s3_models:
355
+ msg = "Couldn't reach server at '{}' to download vocabulary files."
356
+ else:
357
+ msg = "Model name '{}' was not found in tokenizers model name list ({}). " \
358
+ "We assumed '{}' was a path or url to a directory containing vocabulary files " \
359
+ "named {}, but couldn't find such vocabulary files at this path or url.".format(
360
+ pretrained_model_name_or_path, ', '.join(s3_models),
361
+ pretrained_model_name_or_path,
362
+ list(cls.vocab_files_names.values()))
363
+
364
+ raise EnvironmentError(msg)
365
+
366
+ for file_id, file_path in vocab_files.items():
367
+ if file_path == resolved_vocab_files[file_id]:
368
+ logger.info("loading file {}".format(file_path))
369
+ else:
370
+ logger.info("loading file {} from cache at {}".format(
371
+ file_path, resolved_vocab_files[file_id]))
372
+
373
+ # Prepare tokenizer initialization kwargs
374
+ # Did we save some inputs and kwargs to reload?
375
+ tokenizer_config_file = resolved_vocab_files.pop('tokenizer_config_file', None)
376
+ if tokenizer_config_file is not None:
377
+ init_kwargs = json.load(open(tokenizer_config_file, encoding="utf-8"))
378
+ saved_init_inputs = init_kwargs.pop('init_inputs', ())
379
+ if not init_inputs:
380
+ init_inputs = saved_init_inputs
381
+ else:
382
+ init_kwargs = init_configuration
383
+
384
+ # Update with newly provided kwargs
385
+ init_kwargs.update(kwargs)
386
+
387
+ # Set max length if needed
388
+ if pretrained_model_name_or_path in cls.max_model_input_sizes:
389
+ # if we're using a pretrained model, ensure the tokenizer
390
+ # won't index sequences longer than the number of positional embeddings
391
+ max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
392
+ if max_len is not None and isinstance(max_len, (int, float)):
393
+ init_kwargs['max_len'] = min(init_kwargs.get('max_len', int(1e12)), max_len)
394
+
395
+ # Merge resolved_vocab_files arguments in init_kwargs.
396
+ added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
397
+ special_tokens_map_file = resolved_vocab_files.pop('special_tokens_map_file', None)
398
+ for args_name, file_path in resolved_vocab_files.items():
399
+ if args_name not in init_kwargs:
400
+ init_kwargs[args_name] = file_path
401
+ if special_tokens_map_file is not None:
402
+ special_tokens_map = json.load(open(special_tokens_map_file, encoding="utf-8"))
403
+ for key, value in special_tokens_map.items():
404
+ if key not in init_kwargs:
405
+ init_kwargs[key] = value
406
+
407
+ # Instantiate tokenizer.
408
+ tokenizer = cls(*init_inputs, **init_kwargs)
409
+
410
+ # Save inputs and kwargs for saving and re-loading with ``save_pretrained``
411
+ tokenizer.init_inputs = init_inputs
412
+ tokenizer.init_kwargs = init_kwargs
413
+
414
+ # Add supplementary tokens.
415
+ if added_tokens_file is not None:
416
+ added_tok_encoder = json.load(open(added_tokens_file, encoding="utf-8"))
417
+ added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
418
+ tokenizer.added_tokens_encoder.update(added_tok_encoder)
419
+ tokenizer.added_tokens_decoder.update(added_tok_decoder)
420
+
421
+ return tokenizer
422
+
423
+
424
+ def save_pretrained(self, save_directory):
425
+ """ Save the tokenizer vocabulary files together with:
426
+ - added tokens,
427
+ - special-tokens-to-class-attributes-mapping,
428
+ - tokenizer instantiation positional and keywords inputs (e.g. do_lower_case for Bert).
429
+
430
+ This won't save modifications (other than added tokens and the special token mapping) that you may have
431
+ applied to the tokenizer after the instantiation (e.g. modifying tokenizer.do_lower_case after creation).
432
+
433
+ This method makes sure the full tokenizer can then be re-loaded using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method.
434
+ """
435
+ if not os.path.isdir(save_directory):
436
+ logger.error("Saving directory ({}) should be a directory".format(save_directory))
437
+ return
438
+
439
+ special_tokens_map_file = os.path.join(save_directory, SPECIAL_TOKENS_MAP_FILE)
440
+ added_tokens_file = os.path.join(save_directory, ADDED_TOKENS_FILE)
441
+ tokenizer_config_file = os.path.join(save_directory, TOKENIZER_CONFIG_FILE)
442
+
443
+ tokenizer_config = copy.deepcopy(self.init_kwargs)
444
+ tokenizer_config['init_inputs'] = copy.deepcopy(self.init_inputs)
445
+ for file_id in self.vocab_files_names.keys():
446
+ tokenizer_config.pop(file_id, None)
447
+
448
+ with open(tokenizer_config_file, 'w', encoding='utf-8') as f:
449
+ f.write(json.dumps(tokenizer_config, ensure_ascii=False))
450
+
451
+ with open(special_tokens_map_file, 'w', encoding='utf-8') as f:
452
+ f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
453
+
454
+ with open(added_tokens_file, 'w', encoding='utf-8') as f:
455
+ if self.added_tokens_encoder:
456
+ out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
457
+ else:
458
+ out_str = u"{}"
459
+ f.write(out_str)
460
+
461
+ vocab_files = self.save_vocabulary(save_directory)
462
+
463
+ return vocab_files + (special_tokens_map_file, added_tokens_file)
464
+
465
+
466
+ def save_vocabulary(self, save_directory):
467
+ """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
468
+ and special token mappings.
469
+
470
+ Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` to save the full Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained` class method.
471
+ """
472
+ raise NotImplementedError
473
+
474
+
475
+ def vocab_size(self):
476
+ """ Size of the base vocabulary (without the added tokens) """
477
+ raise NotImplementedError
478
+
479
+
480
+ def __len__(self):
481
+ """ Size of the full vocabulary with the added tokens """
482
+ return self.vocab_size + len(self.added_tokens_encoder)
483
+
484
+
485
+ def add_tokens(self, new_tokens):
486
+ """
487
+ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
488
+ vocabulary, they are added to it with indices starting from the length of the current vocabulary.
489
+
490
+ Args:
491
+ new_tokens: list of string. Each string is a token to add. Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
492
+
493
+ Returns:
494
+ Number of tokens added to the vocabulary.
495
+
496
+ Examples::
497
+
498
+ # Let's see how to increase the vocabulary of Bert model and tokenizer
499
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
500
+ model = BertModel.from_pretrained('bert-base-uncased')
501
+
502
+ num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
503
+ print('We have added', num_added_toks, 'tokens')
504
+ model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
505
+ """
506
+ if not new_tokens:
507
+ return 0
508
+
509
+ to_add_tokens = []
510
+ for token in new_tokens:
511
+ assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))
512
+ if token != self.unk_token and \
513
+ self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token) and \
514
+ token not in to_add_tokens:
515
+ to_add_tokens.append(token)
516
+ logger.info("Adding %s to the vocabulary", token)
517
+
518
+ added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
519
+ added_tok_decoder = {v:k for k, v in added_tok_encoder.items()}
520
+ self.added_tokens_encoder.update(added_tok_encoder)
521
+ self.added_tokens_decoder.update(added_tok_decoder)
522
+
523
+ return len(to_add_tokens)
524
+
525
+ def num_added_tokens(self, pair=False):
526
+ """
527
+ Returns the number of added tokens when encoding a sequence with special tokens.
528
+
529
+ Note:
530
+ This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
531
+ inside your training loop.
532
+
533
+ Args:
534
+ pair: If set to True, returns the number of added tokens in the case of a sequence pair;
535
+ if set to False, returns the number of added tokens in the case of a single sequence.
536
+
537
+ Returns:
538
+ Number of tokens added to sequences
539
+ """
540
+ token_ids_0 = []
541
+ token_ids_1 = []
542
+ return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
543
+
544
+ def add_special_tokens(self, special_tokens_dict):
545
+ """
546
+ Add a dictionary of special tokens (eos, pad, cls...) to the encoder and link them
547
+ to class attributes. If special tokens are NOT in the vocabulary, they are added
548
+ to it (indexed starting from the last index of the current vocabulary).
549
+
550
+ Using `add_special_tokens` will ensure your special tokens can be used in several ways:
551
+
552
+ - special tokens are carefully handled by the tokenizer (they are never split)
553
+ - you can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This makes it easy to develop model-agnostic training and fine-tuning scripts.
554
+
555
+ When possible, special tokens are already registered for provided pretrained models (e.g. BertTokenizer's cls_token is already registered as '[CLS]' and XLM's is registered as '</s>')
556
+
557
+ Args:
558
+ special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes:
559
+ [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``,
560
+ ``additional_special_tokens``].
561
+
562
+ Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
563
+
564
+ Returns:
565
+ Number of tokens added to the vocabulary.
566
+
567
+ Examples::
568
+
569
+ # Let's see how to add a new classification token to GPT-2
570
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
571
+ model = GPT2Model.from_pretrained('gpt2')
572
+
573
+ special_tokens_dict = {'cls_token': '<CLS>'}
574
+
575
+ num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
576
+ print('We have added', num_added_toks, 'tokens')
577
+ model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
578
+
579
+ assert tokenizer.cls_token == '<CLS>'
580
+ """
581
+ if not special_tokens_dict:
582
+ return 0
583
+
584
+ added_tokens = 0
585
+ for key, value in special_tokens_dict.items():
586
+ assert key in self.SPECIAL_TOKENS_ATTRIBUTES
587
+ if key == 'additional_special_tokens':
588
+ assert isinstance(value, (list, tuple)) and all(isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value)
589
+ added_tokens += self.add_tokens(value)
590
+ else:
591
+ assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))
592
+ added_tokens += self.add_tokens([value])
593
+ logger.info("Assigning %s to the %s key of the tokenizer", value, key)
594
+ setattr(self, key, value)
595
+
596
+ return added_tokens
597
+
598
+ def tokenize(self, text, **kwargs):
599
+ """ Converts a string in a sequence of tokens (string), using the tokenizer.
600
+ Split in words for word-based vocabulary or sub-words for sub-word-based
601
+ vocabularies (BPE/SentencePieces/WordPieces).
602
+
603
+ Take care of added tokens.
604
+ """
605
+ def split_on_token(tok, text):
606
+ result = []
607
+ split_text = text.split(tok)
608
+ for i, sub_text in enumerate(split_text):
609
+ sub_text = sub_text.strip()
610
+ if i == 0 and not sub_text:
611
+ result += [tok]
612
+ elif i == len(split_text) - 1:
613
+ if sub_text:
614
+ result += [sub_text]
615
+ else:
616
+ pass
617
+ else:
618
+ if sub_text:
619
+ result += [sub_text]
620
+ result += [tok]
621
+ return result
622
+
623
+ def split_on_tokens(tok_list, text):
624
+ if not text:
625
+ return []
626
+ if not tok_list:
627
+ return self._tokenize(text, **kwargs)
628
+
629
+ tokenized_text = []
630
+ text_list = [text]
631
+ for tok in tok_list:
632
+ tokenized_text = []
633
+ for sub_text in text_list:
634
+ if sub_text not in self.added_tokens_encoder \
635
+ and sub_text not in self.all_special_tokens:
636
+ tokenized_text += split_on_token(tok, sub_text)
637
+ else:
638
+ tokenized_text += [sub_text]
639
+ text_list = tokenized_text
640
+
641
+ return sum((self._tokenize(token, **kwargs) if token not \
642
+ in self.added_tokens_encoder and token not in self.all_special_tokens \
643
+ else [token] for token in tokenized_text), [])
644
+
645
+ added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
646
+ tokenized_text = split_on_tokens(added_tokens, text)
647
+ return tokenized_text
648
+
649
+ def _tokenize(self, text, **kwargs):
650
+ """ Converts a string in a sequence of tokens (string), using the tokenizer.
651
+ Split in words for word-based vocabulary or sub-words for sub-word-based
652
+ vocabularies (BPE/SentencePieces/WordPieces).
653
+
654
+ Do NOT take care of added tokens.
655
+ """
656
+ raise NotImplementedError
657
+
658
+ def convert_tokens_to_ids(self, tokens):
659
+ """ Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
660
+ (resp. a sequence of ids), using the vocabulary.
661
+ """
662
+ if tokens is None:
663
+ return None
664
+
665
+ if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
666
+ return self._convert_token_to_id_with_added_voc(tokens)
667
+
668
+ ids = []
669
+ for token in tokens:
670
+ ids.append(self._convert_token_to_id_with_added_voc(token))
671
+ if len(ids) > self.max_len:
672
+ logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
673
+ "for this model ({} > {}). Running this sequence through the model will result in "
674
+ "indexing errors".format(len(ids), self.max_len))
675
+ return ids
676
+
677
+ def _convert_token_to_id_with_added_voc(self, token):
678
+ if token is None:
679
+ return None
680
+
681
+ if token in self.added_tokens_encoder:
682
+ return self.added_tokens_encoder[token]
683
+ return self._convert_token_to_id(token)
684
+
685
+ def _convert_token_to_id(self, token):
686
+ raise NotImplementedError
687
+
688
+ def encode(self,
689
+ text,
690
+ text_pair=None,
691
+ add_special_tokens=False,
692
+ max_length=None,
693
+ stride=0,
694
+ truncation_strategy='longest_first',
695
+ return_tensors=None,
696
+ **kwargs):
697
+ """
698
+ Converts a string into a sequence of ids (integers), using the tokenizer and vocabulary.
699
+
700
+ Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
701
+
702
+ Args:
703
+ text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
704
+ the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
705
+ method)
706
+ text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
707
+ string using the `tokenize` method) or a list of integers (tokenized string ids using the
708
+ `convert_tokens_to_ids` method)
709
+ add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
710
+ to their model.
711
+ max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
712
+ If there are overflowing tokens, those will be added to the returned dictionary
713
+ stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
714
+ from the main sequence returned. The value of this argument defines the number of additional tokens.
715
+ truncation_strategy: string selected in the following options:
716
+ - 'longest_first' (default): Iteratively reduce the input sequences until the total length is under max_length,
717
+ removing one token at a time from the longest sequence (when there is a pair of input sequences)
718
+ - 'only_first': Only truncate the first sequence
719
+ - 'only_second': Only truncate the second sequence
720
+ - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
721
+ return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
722
+ or PyTorch torch.Tensor instead of a list of python integers.
723
+ **kwargs: passed to the `self.tokenize()` method
724
+ """
725
+ encoded_inputs = self.encode_plus(text,
726
+ text_pair=text_pair,
727
+ max_length=max_length,
728
+ add_special_tokens=add_special_tokens,
729
+ stride=stride,
730
+ truncation_strategy=truncation_strategy,
731
+ return_tensors=return_tensors,
732
+ **kwargs)
733
+
734
+ return encoded_inputs["input_ids"]
735
+
736
+ def encode_plus(self,
737
+ text,
738
+ text_pair=None,
739
+ add_special_tokens=False,
740
+ max_length=None,
741
+ stride=0,
742
+ truncation_strategy='longest_first',
743
+ return_tensors=None,
744
+ **kwargs):
745
+ """
746
+ Returns a dictionary containing the encoded sequence or sequence pair and additional information:
747
+ the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
748
+
749
+ Args:
750
+ text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
751
+ the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
752
+ method)
753
+ text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
754
+ string using the `tokenize` method) or a list of integers (tokenized string ids using the
755
+ `convert_tokens_to_ids` method)
756
+ add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
757
+ to their model.
758
+ max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
759
+ If there are overflowing tokens, those will be added to the returned dictionary
760
+ stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
761
+ from the main sequence returned. The value of this argument defines the number of additional tokens.
762
+ truncation_strategy: string selected in the following options:
763
+ - 'longest_first' (default): Iteratively reduce the input sequences until the total length is under max_length,
764
+ removing one token at a time from the longest sequence (when there is a pair of input sequences)
765
+ - 'only_first': Only truncate the first sequence
766
+ - 'only_second': Only truncate the second sequence
767
+ - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
768
+ return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
769
+ or PyTorch torch.Tensor instead of a list of python integers.
770
+ **kwargs: passed to the `self.tokenize()` method
771
+ """
772
+
773
+ def get_input_ids(text):
774
+ if isinstance(text, six.string_types):
775
+ return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
776
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], six.string_types):
777
+ return self.convert_tokens_to_ids(text)
778
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
779
+ return text
780
+ else:
781
+ raise ValueError("Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers.")
782
+
783
+ first_ids = get_input_ids(text)
784
+ second_ids = get_input_ids(text_pair) if text_pair is not None else None
785
+
786
+ return self.prepare_for_model(first_ids,
787
+ pair_ids=second_ids,
788
+ max_length=max_length,
789
+ add_special_tokens=add_special_tokens,
790
+ stride=stride,
791
+ truncation_strategy=truncation_strategy,
792
+ return_tensors=return_tensors)
793
+
794
+ def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=False, stride=0,
795
+ truncation_strategy='longest_first', return_tensors=None):
796
+ """
797
+ Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model.
798
+ It adds special tokens, truncates
799
+ sequences if overflowing while taking into account the special tokens and manages a window stride for
800
+ overflowing tokens
801
+
802
+ Args:
803
+ ids: list of tokenized input ids. Can be obtained from a string by chaining the
804
+ `tokenize` and `convert_tokens_to_ids` methods.
805
+ pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
806
+ `tokenize` and `convert_tokens_to_ids` methods.
807
+ max_length: maximum length of the returned list. Will truncate by taking into account the special tokens.
808
+ add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
809
+ to their model.
810
+ stride: window stride for overflowing tokens. Can be useful for edge effect removal when using sequential
811
+ list of inputs.
812
+ truncation_strategy: string selected in the following options:
813
+ - 'longest_first' (default): Iteratively reduce the input sequences until the total length is under max_length,
814
+ removing one token at a time from the longest sequence (when there is a pair of input sequences)
815
+ - 'only_first': Only truncate the first sequence
816
+ - 'only_second': Only truncate the second sequence
817
+ - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
818
+ return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
819
+ or PyTorch torch.Tensor instead of a list of python integers.
820
+
821
+ Return:
822
+ A Dictionary of shape::
823
+
824
+ {
825
+ input_ids: list[int],
826
+ overflowing_tokens: list[int] if a ``max_length`` is specified, else None
827
+ special_tokens_mask: list[int] if ``add_special_tokens`` is set to ``True``
828
+ }
829
+
830
+ With the fields:
831
+ ``input_ids``: list of tokens to be fed to a model
832
+
833
+ ``overflowing_tokens``: list of overflowing tokens if a max length is specified.
834
+
835
+ ``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
836
+ tokens and 1 specifying sequence tokens.
837
+ """
838
+ pair = bool(pair_ids is not None)
839
+ len_ids = len(ids)
840
+ len_pair_ids = len(pair_ids) if pair else 0
841
+
842
+ encoded_inputs = {}
843
+ total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0)
844
+ if max_length and total_len > max_length:
845
+ ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids,
846
+ num_tokens_to_remove=total_len-max_length,
847
+ truncation_strategy=truncation_strategy,
848
+ stride=stride)
849
+ encoded_inputs["overflowing_tokens"] = overflowing_tokens
850
+ encoded_inputs["num_truncated_tokens"] = total_len - max_length
851
+
852
+ if add_special_tokens:
853
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
854
+ token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
855
+ encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
856
+ else:
857
+ sequence = ids + pair_ids if pair else ids
858
+ token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
859
+
860
+ if return_tensors == 'tf' and is_tf_available():
861
+ sequence = tf.constant([sequence])
862
+ token_type_ids = tf.constant([token_type_ids])
863
+ elif return_tensors == 'pt' and is_torch_available():
864
+ sequence = torch.tensor([sequence])
865
+ token_type_ids = torch.tensor([token_type_ids])
866
+ elif return_tensors is not None:
867
+ logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors))
868
+
869
+ encoded_inputs["input_ids"] = sequence
870
+ encoded_inputs["token_type_ids"] = token_type_ids
871
+
872
+ if max_length and len(encoded_inputs["input_ids"]) > max_length:
873
+ encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
874
+ encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
875
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
876
+
877
+ return encoded_inputs
878
+
879
+ def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0):
880
+ """Truncates a sequence pair in place to the maximum length.
881
+ truncation_strategy: string selected in the following options:
882
+ - 'longest_first' (default): Iteratively reduce the input sequences until the total length is under max_length,
883
+ removing one token at a time from the longest sequence (when there is a pair of input sequences).
884
+ Overflowing tokens only contain overflow from the first sequence.
885
+ - 'only_first': Only truncate the first sequence. Raise an error if the first sequence is shorter than or equal to num_tokens_to_remove.
886
+ - 'only_second': Only truncate the second sequence
887
+ - 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
888
+ """
889
+ if num_tokens_to_remove <= 0:
890
+ return ids, pair_ids, []
891
+
892
+ if truncation_strategy == 'longest_first':
893
+ overflowing_tokens = []
894
+ for _ in range(num_tokens_to_remove):
895
+ if pair_ids is None or len(ids) > len(pair_ids):
896
+ overflowing_tokens = [ids[-1]] + overflowing_tokens
897
+ ids = ids[:-1]
898
+ else:
899
+ pair_ids = pair_ids[:-1]
900
+ window_len = min(len(ids), stride)
901
+ if window_len > 0:
902
+ overflowing_tokens = ids[-window_len:] + overflowing_tokens
903
+ elif truncation_strategy == 'only_first':
904
+ assert len(ids) > num_tokens_to_remove
905
+ window_len = min(len(ids), stride + num_tokens_to_remove)
906
+ overflowing_tokens = ids[-window_len:]
907
+ ids = ids[:-num_tokens_to_remove]
908
+ elif truncation_strategy == 'only_second':
909
+ assert pair_ids is not None and len(pair_ids) > num_tokens_to_remove
910
+ window_len = min(len(pair_ids), stride + num_tokens_to_remove)
911
+ overflowing_tokens = pair_ids[-window_len:]
912
+ pair_ids = pair_ids[:-num_tokens_to_remove]
913
+ elif truncation_strategy == 'do_not_truncate':
914
+ raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.")
915
+ else:
916
+ raise ValueError("Truncation_strategy should be selected in ['longest_first', 'only_first', 'only_second', 'do_not_truncate']")
917
+ return (ids, pair_ids, overflowing_tokens)
918
+
919
+ def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
920
+ logger.warning("This tokenizer does not make use of special tokens.")
921
+ if token_ids_1 is None:
922
+ return len(token_ids_0) * [0]
923
+ return [0] * len(token_ids_0) + [1] * len(token_ids_1)
924
+
925
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
926
+ """
927
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks
928
+ by concatenating and adding special tokens.
929
+ A RoBERTa sequence, for example, has the following format:
930
+ single sequence: <s> X </s>
931
+ pair of sequences: <s> A </s></s> B </s>
932
+ """
933
+ logger.warning("This tokenizer does not make use of special tokens. Input is returned with no modification.")
934
+ if token_ids_1 is None:
935
+ return token_ids_0
936
+ return token_ids_0 + token_ids_1
937
+
938
+ def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
939
+ """
940
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
941
+ special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
942
+
943
+ Args:
944
+ token_ids_0: list of ids (must not contain special tokens)
945
+ token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
946
+ for sequence pairs
947
+ already_has_special_tokens: (default False) Set to True if the token list is already formatted with
948
+ special tokens for the model
949
+
950
+ Returns:
951
+ A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
952
+ """
953
+ return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
954
+
955
+ def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
956
+ """ Converts a single index or a sequence of indices (integers) in a token "
957
+ (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
958
+
959
+ Args:
960
+ skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
961
+ """
962
+ if isinstance(ids, int):
963
+ if ids in self.added_tokens_decoder:
964
+ return self.added_tokens_decoder[ids]
965
+ else:
966
+ return self._convert_id_to_token(ids)
967
+ tokens = []
968
+ for index in ids:
969
+ if skip_special_tokens and index in self.all_special_ids:
970
+ continue
971
+ if index in self.added_tokens_decoder:
972
+ tokens.append(self.added_tokens_decoder[index])
973
+ else:
974
+ tokens.append(self._convert_id_to_token(index))
975
+ return tokens
976
+
977
+ def _convert_id_to_token(self, index):
978
+ raise NotImplementedError
979
+
980
+ def convert_tokens_to_string(self, tokens):
981
+ """ Converts a sequence of tokens (string) in a single string.
982
+ The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
983
+ but we often want to remove sub-word tokenization artifacts at the same time.
984
+ """
985
+ return ' '.join(self.convert_ids_to_tokens(tokens))
986
+
987
+ def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
988
+ """
989
+ Converts a sequence of ids (integers) into a string, using the tokenizer and vocabulary
990
+ with options to remove special tokens and clean up tokenization spaces.
991
+ Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
992
+
993
+ Args:
994
+ token_ids: list of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
995
+ skip_special_tokens: if set to True, will remove special tokens from the decoded output.
996
+ clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
997
+ """
998
+ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
999
+
1000
+ # To avoid mixing byte-level and unicode for byte-level BPE
1001
+ # we need to build the string separately for added tokens and byte-level tokens
1002
+ # cf. https://github.com/huggingface/transformers/issues/1133
1003
+ sub_texts = []
1004
+ current_sub_text = []
1005
+ for token in filtered_tokens:
1006
+ if skip_special_tokens and token in self.all_special_tokens:  # compare token strings, not ids
1007
+ continue
1008
+ if token in self.added_tokens_encoder:
1009
+ if current_sub_text:
1010
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
1011
+ current_sub_text = []
1012
+ sub_texts.append(" " + token)
1013
+ else:
1014
+ current_sub_text.append(token)
1015
+ if current_sub_text:
1016
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
1017
+ text = ''.join(sub_texts)
1018
+
1019
+ if clean_up_tokenization_spaces:
1020
+ clean_text = self.clean_up_tokenization(text)
1021
+ return clean_text
1022
+ else:
1023
+ return text
1024
+
1025
+ @property
1026
+ def special_tokens_map(self):
1027
+ """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their
1028
+ values ('<unk>', '<cls>'...)
1029
+ """
1030
+ set_attr = {}
1031
+ for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
1032
+ attr_value = getattr(self, "_" + attr)
1033
+ if attr_value:
1034
+ set_attr[attr] = attr_value
1035
+ return set_attr
1036
+
1037
+ @property
1038
+ def all_special_tokens(self):
1039
+ """ List all the special tokens ('<unk>', '<cls>'...) mapped to class attributes
1040
+ (cls_token, unk_token...).
1041
+ """
1042
+ all_toks = []
1043
+ set_attr = self.special_tokens_map
1044
+ for attr_value in set_attr.values():
1045
+ all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value])
1046
+ all_toks = list(set(all_toks))
1047
+ return all_toks
1048
+
1049
+ @property
1050
+ def all_special_ids(self):
1051
+ """ List the vocabulary indices of the special tokens ('<unk>', '<cls>'...) mapped to
1052
+ class attributes (cls_token, unk_token...).
1053
+ """
1054
+ all_toks = self.all_special_tokens
1055
+ all_ids = list(self._convert_token_to_id(t) for t in all_toks)
1056
+ return all_ids
1057
+
1058
+ @staticmethod
1059
+ def clean_up_tokenization(out_string):
1060
+ """ Clean up a list of simple English tokenization artifacts like spaces before punctuations and abreviated forms.
1061
+ """
1062
+ out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
1063
+ ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
1064
+ ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
1065
+ return out_string
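The abstract base tokenizer above only becomes usable through a subclass that supplies a vocabulary and the `_tokenize` / `_convert_token_to_id` / `_convert_id_to_token` hooks. Below is a minimal toy sketch of such a subclass, assuming the class above is importable as `PreTrainedTokenizer` from `model.tokenization_utils` (the name its own docstrings reference); the whitespace tokenizer and three-entry vocabulary are purely illustrative and not part of this repository.

# Toy subclass sketch; assumes model/tokenization_utils.py exports PreTrainedTokenizer.
from model.tokenization_utils import PreTrainedTokenizer

class ToyTokenizer(PreTrainedTokenizer):
    """Whitespace tokenizer over a tiny in-memory vocabulary (illustration only)."""

    def __init__(self, vocab, **kwargs):
        super(ToyTokenizer, self).__init__(**kwargs)  # registers special tokens like unk_token
        self.vocab = vocab
        self.ids_to_tokens = {i: t for t, i in vocab.items()}

    @property
    def vocab_size(self):
        return len(self.vocab)

    def _tokenize(self, text, **kwargs):
        return text.split()

    def _convert_token_to_id(self, token):
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index):
        return self.ids_to_tokens.get(index, self.unk_token)

vocab = {'[UNK]': 0, 'hello': 1, 'world': 2}
tok = ToyTokenizer(vocab, unk_token='[UNK]')
print(tok.tokenize('hello world foo'))  # ['hello', 'world', 'foo']
print(tok.encode('hello world foo'))    # [1, 2, 0] -- 'foo' maps to the unk id
print(tok.decode([1, 2, 0]))            # 'hello world [UNK]'

The base class then provides `encode`, `encode_plus`, `decode`, `add_tokens`, `save_pretrained`, and the truncation logic for free on top of these three hooks.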
test.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model.tokenization_albert import FullTokenizer
2
+ from model.modeling_albert import AlbertModel
3
+ import torch
4
+
5
+ content = 'ལས་ཁུངས་ཀྱི་ཏང་འཛུགས་སྐྱོང་གི་སྤུས་ཚད་ཕྱོགས་ཡོངས་ནས་མཐོར་གཏོང་བཅས་བྱེད་པའི་བྱེད་ཐབས་གལ་ཆེན་ཞིག་ཡིན་ལ།'
6
+ tokenizer = FullTokenizer(vocab_file='tibetan-albert-syllable-base/vocab.txt', do_lower_case=False)
7
+ token = content.split('་')
8
+ print(token)
9
+ token_ids = tokenizer.convert_tokens_to_ids(token)
10
+ print(token_ids)
11
+ token_ids = torch.LongTensor([token_ids])
12
+
13
+ albert_model = AlbertModel.from_pretrained('tibetan-albert-syllable-base')
14
+
15
+ output = albert_model(input_ids=token_ids)
16
+ print(output)
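test.py above runs a single forward pass and prints the raw result. The sketch below extends it to unpack that result, assuming the AlbertModel in model/modeling_albert.py follows the usual BERT-style convention of returning a (sequence_output, pooled_output, ...) tuple; check its forward() to confirm the ordering before relying on it.

# Sketch only: assumes a BERT-style output tuple (sequence_output, pooled_output, ...).
import torch
from model.tokenization_albert import FullTokenizer
from model.modeling_albert import AlbertModel

tokenizer = FullTokenizer(vocab_file='tibetan-albert-syllable-base/vocab.txt', do_lower_case=False)
model = AlbertModel.from_pretrained('tibetan-albert-syllable-base')
model.eval()  # disable dropout for deterministic inference

content = 'ལས་ཁུངས་ཀྱི་ཏང་འཛུགས་སྐྱོང་གི་སྤུས་ཚད་ཕྱོགས་ཡོངས་ནས་མཐོར་གཏོང་བཅས་བྱེད་པའི་བྱེད་ཐབས་གལ་ཆེན་ཞིག་ཡིན་ལ།'
token_ids = torch.LongTensor([tokenizer.convert_tokens_to_ids(content.split('་'))])

with torch.no_grad():
    outputs = model(input_ids=token_ids)

sequence_output = outputs[0]  # assumed shape: (batch, num_syllables, hidden_size)
pooled_output = outputs[1]    # assumed shape: (batch, hidden_size)
print(sequence_output.shape, pooled_output.shape)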
tibetan-albert-syllable-base/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_probs_dropout_prob": 0.0,
3
+ "directionality": "bidi",
4
+ "embedding_size": 128,
5
+ "finetuning_task": null,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.0,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "inner_group_num": 1,
11
+ "intermediate_size": 3072,
12
+ "layer_norm_eps": 1e-12,
13
+ "ln_type": "postln",
14
+ "max_position_embeddings": 512,
15
+ "num_attention_heads": 12,
16
+ "num_hidden_groups": 1,
17
+ "num_hidden_layers": 12,
18
+ "num_labels": 2,
19
+ "output_attentions": false,
20
+ "output_hidden_states": true,
21
+ "pooler_fc_size": 768,
22
+ "pooler_num_attention_heads": 12,
23
+ "pooler_num_fc_layers": 3,
24
+ "pooler_size_per_head": 128,
25
+ "pooler_type": "first_token_transform",
26
+ "pruned_heads": {},
27
+ "share_type": "all",
28
+ "torchscript": false,
29
+ "type_vocab_size": 2,
30
+ "vocab_size": 18907
31
+ }
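Two values in this config capture ALBERT's main size reductions: the factorized embedding ("embedding_size": 128 projected up to "hidden_size": 768) and cross-layer parameter sharing ("share_type": "all" with a single hidden group). A rough, illustrative sketch of what the factorization alone saves on the embedding table, using only the numbers in this file and ignoring biases and layer norms:

import json

# Back-of-the-envelope embedding-parameter count from config.json (sketch only).
with open('tibetan-albert-syllable-base/config.json', encoding='utf-8') as f:
    cfg = json.load(f)

V, E, H = cfg['vocab_size'], cfg['embedding_size'], cfg['hidden_size']

factorized = V * E + E * H  # ALBERT: V x 128 lookup table plus a 128 -> 768 projection
unfactorized = V * H        # BERT-style: V x 768 lookup table directly

print('factorized:  ', factorized)    # 18907*128 + 128*768 = 2,518,400
print('unfactorized:', unfactorized)  # 18907*768           = 14,520,576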
tibetan-albert-syllable-base/optimizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:417da4c35036bdf0e172665a3803074e0308e23107c95c184d720c75527dfbb2
3
+ size 83071196
tibetan-albert-syllable-base/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a79ecf11416b44059887940e9c1b52410475f76b55c09f036808e0f157dedd7
3
+ size 41539500
tibetan-albert-syllable-base/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b0bccd161964d4d2eb645235dcc6becc153ddc286f1064a659792035b02d56a
3
+ size 905
tibetan-albert-syllable-base/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
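Although vocab.txt is too large for the diff view, FullTokenizer-style vocabularies are plain one-token-per-line files. A quick sanity check (a sketch, assuming the file ships at the path test.py uses) is to confirm its length matches "vocab_size": 18907 in config.json:

# Sketch: count vocabulary entries and compare with config.json's vocab_size.
with open('tibetan-albert-syllable-base/vocab.txt', encoding='utf-8') as f:
    vocab = [line.rstrip('\n') for line in f if line.strip()]

print(len(vocab))  # expected to equal 18907, the "vocab_size" in config.json
print(vocab[:5])   # first few entries, typically special tokens such as [PAD]/[UNK]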