dianecy commited on
Commit
8d82201
·
verified ·
1 Parent(s): 0b32e3c

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +12 -0
  2. LAVT-RIS/__pycache__/args.cpython-39.pyc +0 -0
  3. LAVT-RIS/__pycache__/train_refzom.cpython-39.pyc +0 -0
  4. LAVT-RIS/__pycache__/transforms.cpython-39.pyc +0 -0
  5. LAVT-RIS/__pycache__/utils.cpython-39.pyc +0 -0
  6. LAVT-RIS/angle_vis.ipynb +0 -0
  7. LAVT-RIS/args.py +62 -0
  8. LAVT-RIS/bert/__pycache__/activations.cpython-37.pyc +0 -0
  9. LAVT-RIS/bert/__pycache__/activations.cpython-39.pyc +0 -0
  10. LAVT-RIS/bert/__pycache__/configuration_bert.cpython-37.pyc +0 -0
  11. LAVT-RIS/bert/__pycache__/configuration_bert.cpython-39.pyc +0 -0
  12. LAVT-RIS/bert/__pycache__/configuration_utils.cpython-37.pyc +0 -0
  13. LAVT-RIS/bert/__pycache__/configuration_utils.cpython-39.pyc +0 -0
  14. LAVT-RIS/bert/__pycache__/file_utils.cpython-37.pyc +0 -0
  15. LAVT-RIS/bert/__pycache__/file_utils.cpython-39.pyc +0 -0
  16. LAVT-RIS/bert/__pycache__/generation_utils.cpython-37.pyc +0 -0
  17. LAVT-RIS/bert/__pycache__/generation_utils.cpython-39.pyc +0 -0
  18. LAVT-RIS/bert/__pycache__/modeling_bert.cpython-37.pyc +0 -0
  19. LAVT-RIS/bert/__pycache__/modeling_bert.cpython-39.pyc +0 -0
  20. LAVT-RIS/bert/__pycache__/modeling_utils.cpython-37.pyc +0 -0
  21. LAVT-RIS/bert/__pycache__/modeling_utils.cpython-39.pyc +0 -0
  22. LAVT-RIS/bert/__pycache__/tokenization_bert.cpython-39.pyc +0 -0
  23. LAVT-RIS/bert/__pycache__/tokenization_utils.cpython-39.pyc +0 -0
  24. LAVT-RIS/bert/__pycache__/tokenization_utils_base.cpython-39.pyc +0 -0
  25. LAVT-RIS/bert/activations.py +56 -0
  26. LAVT-RIS/bert/configuration_bert.py +143 -0
  27. LAVT-RIS/bert/configuration_utils.py +408 -0
  28. LAVT-RIS/bert/file_utils.py +808 -0
  29. LAVT-RIS/bert/generation_utils.py +993 -0
  30. LAVT-RIS/bert/modeling_bert.py +1569 -0
  31. LAVT-RIS/bert/modeling_utils.py +1268 -0
  32. LAVT-RIS/bert/tokenization_bert.py +545 -0
  33. LAVT-RIS/bert/tokenization_utils.py +723 -0
  34. LAVT-RIS/bert/tokenization_utils_base.py +0 -0
  35. LAVT-RIS/data.ipynb +0 -0
  36. LAVT-RIS/data/__pycache__/dataset_refer_bert.cpython-39.pyc +0 -0
  37. LAVT-RIS/data/__pycache__/dataset_refer_bert_mostat.cpython-39.pyc +0 -0
  38. LAVT-RIS/data/__pycache__/dataset_refer_bert_rev.cpython-39.pyc +0 -0
  39. LAVT-RIS/data/__pycache__/dataset_refer_zom.cpython-39.pyc +0 -0
  40. LAVT-RIS/data/dataset_refer_bert.py +228 -0
  41. LAVT-RIS/data/dataset_refer_bert_mostat.py +136 -0
  42. LAVT-RIS/data/dataset_refer_bert_rev.py +246 -0
  43. LAVT-RIS/data/dataset_refer_zom.py +296 -0
  44. LAVT-RIS/datagen.txt +49 -0
  45. LAVT-RIS/demo_inference.py +118 -0
  46. LAVT-RIS/donghwa/args.py +212 -0
  47. LAVT-RIS/donghwa/config/__pycache__/utils.cpython-37.pyc +0 -0
  48. LAVT-RIS/donghwa/config/n_obj/n_12.yaml +1 -0
  49. LAVT-RIS/donghwa/config/n_obj/n_34.yaml +1 -0
  50. LAVT-RIS/donghwa/config/n_obj/n_56.yaml +1 -0
.gitattributes CHANGED
@@ -34,3 +34,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  RIS-DMMI/refer/evaluation/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  RIS-DMMI/refer/evaluation/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text
37
+ LAVT-RIS/logs/old/gref_m10_mg12_tmp007_2gpu_bs16_ang.log filter=lfs diff=lfs merge=lfs -text
38
+ LAVT-RIS/refer/data/ref-zom/instances.json filter=lfs diff=lfs merge=lfs -text
39
+ LAVT-RIS/refer/data/ref-zom/refs(final).p filter=lfs diff=lfs merge=lfs -text
40
+ LAVT-RIS/refer/data/refcoco/instances.json filter=lfs diff=lfs merge=lfs -text
41
+ LAVT-RIS/refer/data/refcoco/refs(google).p filter=lfs diff=lfs merge=lfs -text
42
+ LAVT-RIS/refer/data/refcoco/refs(unc).p filter=lfs diff=lfs merge=lfs -text
43
+ LAVT-RIS/refer/data/refcoco+/instances.json filter=lfs diff=lfs merge=lfs -text
44
+ LAVT-RIS/refer/data/refcoco+/refs(unc).p filter=lfs diff=lfs merge=lfs -text
45
+ LAVT-RIS/refer/data/refcocog/instances.json filter=lfs diff=lfs merge=lfs -text
46
+ LAVT-RIS/refer/data/refcocog/refs(google).p filter=lfs diff=lfs merge=lfs -text
47
+ LAVT-RIS/refer/data/refcocog/refs(umd).p filter=lfs diff=lfs merge=lfs -text
48
+ LAVT-RIS/refer/evaluation/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text
LAVT-RIS/__pycache__/args.cpython-39.pyc ADDED
Binary file (3.23 kB). View file
 
LAVT-RIS/__pycache__/train_refzom.cpython-39.pyc ADDED
Binary file (10.1 kB). View file
 
LAVT-RIS/__pycache__/transforms.cpython-39.pyc ADDED
Binary file (4.93 kB). View file
 
LAVT-RIS/__pycache__/utils.cpython-39.pyc ADDED
Binary file (7.27 kB). View file
 
LAVT-RIS/angle_vis.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
LAVT-RIS/args.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+
4
+ def get_parser():
5
+ parser = argparse.ArgumentParser(description='LAVT training and testing')
6
+ parser.add_argument('--amsgrad', action='store_true',
7
+ help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
8
+ parser.add_argument('-b', '--batch-size', default=8, type=int)
9
+ parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer')
10
+ parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
11
+ parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
12
+ parser.add_argument('--ddp_trained_weights', action='store_true',
13
+ help='Only needs specified when testing,'
14
+ 'whether the weights to be loaded are from a DDP-trained model')
15
+ parser.add_argument('--device', default='cuda:0', help='device') # only used when testing on a single machine
16
+ parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
17
+ parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
18
+ parser.add_argument('--img_size', default=480, type=int, help='input image size')
19
+ # parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel')
20
+ parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
21
+ parser.add_argument('--mha', default='', help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4,'
22
+ 'where a, b, c, and d refer to the numbers of heads in stage-1,'
23
+ 'stage-2, stage-3, and stage-4 PWAMs')
24
+ parser.add_argument('--model', default='lavt', help='model: lavt, lavt_one')
25
+ parser.add_argument('--model_id', default='lavt', help='name to identify the model')
26
+ parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
27
+ parser.add_argument('--pin_mem', action='store_true',
28
+ help='If true, pin memory when using the data loader.')
29
+ parser.add_argument('--pretrained_swin_weights', default='',
30
+ help='path to pre-trained Swin backbone weights')
31
+ parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
32
+ parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
33
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
34
+ parser.add_argument('--split', default='test', help='only used when testing')
35
+ parser.add_argument('--splitBy', default='unc', help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
36
+ parser.add_argument('--swin_type', default='base',
37
+ help='tiny, small, base, or large variants of the Swin Transformer')
38
+ parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
39
+ dest='weight_decay')
40
+ parser.add_argument('--window12', action='store_true',
41
+ help='only needs specified when testing,'
42
+ 'when training, window size is inferred from pre-trained weights file name'
43
+ '(containing \'window12\'). Initialize Swin with window size 12 instead of the default 7.')
44
+ parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers')
45
+
46
+ # metric loss related ones
47
+ parser.add_argument('--metric_learning', action='store_true',help='whether to use metric learning')
48
+ parser.add_argument('--metric_loss_weight', default=0.1, type=float, help='weight for metric loss')
49
+ parser.add_argument('--metric_mode', default='hardpos_rev3', help='test options..')
50
+ parser.add_argument('--exclude_multiobj', action='store_true', help='exclude multi-object images')
51
+ parser.add_argument('--hn_prob', default=0.0, type=float, help='hard negative probability')
52
+ parser.add_argument('--hp_selection', default='naive', help='test options..')
53
+ parser.add_argument('--margin_value', default=10, type=float, help='weight for metric loss')
54
+ parser.add_argument('--temperature', default=0.05, type=float, help='test options..')
55
+ parser.add_argument('--addzero', action='store_true', help='test options..')
56
+
57
+ return parser
58
+
59
+
60
+ if __name__ == "__main__":
61
+ parser = get_parser()
62
+ args_dict = parser.parse_args()
LAVT-RIS/bert/__pycache__/activations.cpython-37.pyc ADDED
Binary file (1.96 kB). View file
 
LAVT-RIS/bert/__pycache__/activations.cpython-39.pyc ADDED
Binary file (1.94 kB). View file
 
LAVT-RIS/bert/__pycache__/configuration_bert.cpython-37.pyc ADDED
Binary file (7.87 kB). View file
 
LAVT-RIS/bert/__pycache__/configuration_bert.cpython-39.pyc ADDED
Binary file (7.88 kB). View file
 
LAVT-RIS/bert/__pycache__/configuration_utils.cpython-37.pyc ADDED
Binary file (16.2 kB). View file
 
LAVT-RIS/bert/__pycache__/configuration_utils.cpython-39.pyc ADDED
Binary file (16.3 kB). View file
 
LAVT-RIS/bert/__pycache__/file_utils.cpython-37.pyc ADDED
Binary file (24.4 kB). View file
 
LAVT-RIS/bert/__pycache__/file_utils.cpython-39.pyc ADDED
Binary file (24.7 kB). View file
 
LAVT-RIS/bert/__pycache__/generation_utils.cpython-37.pyc ADDED
Binary file (27.9 kB). View file
 
LAVT-RIS/bert/__pycache__/generation_utils.cpython-39.pyc ADDED
Binary file (28 kB). View file
 
LAVT-RIS/bert/__pycache__/modeling_bert.cpython-37.pyc ADDED
Binary file (56.4 kB). View file
 
LAVT-RIS/bert/__pycache__/modeling_bert.cpython-39.pyc ADDED
Binary file (55.2 kB). View file
 
LAVT-RIS/bert/__pycache__/modeling_utils.cpython-37.pyc ADDED
Binary file (48 kB). View file
 
LAVT-RIS/bert/__pycache__/modeling_utils.cpython-39.pyc ADDED
Binary file (48 kB). View file
 
LAVT-RIS/bert/__pycache__/tokenization_bert.cpython-39.pyc ADDED
Binary file (19.3 kB). View file
 
LAVT-RIS/bert/__pycache__/tokenization_utils.cpython-39.pyc ADDED
Binary file (24.9 kB). View file
 
LAVT-RIS/bert/__pycache__/tokenization_utils_base.cpython-39.pyc ADDED
Binary file (82.4 kB). View file
 
LAVT-RIS/bert/activations.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ def swish(x):
12
+ return x * torch.sigmoid(x)
13
+
14
+
15
+ def _gelu_python(x):
16
+ """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
17
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
18
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
19
+ This is now written in C in torch.nn.functional
20
+ Also see https://arxiv.org/abs/1606.08415
21
+ """
22
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
23
+
24
+
25
+ def gelu_new(x):
26
+ """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
27
+ Also see https://arxiv.org/abs/1606.08415
28
+ """
29
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
30
+
31
+
32
+ if torch.__version__ < "1.4.0":
33
+ gelu = _gelu_python
34
+ else:
35
+ gelu = F.gelu
36
+
37
+
38
+ def gelu_fast(x):
39
+ return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
40
+
41
+
42
+ ACT2FN = {
43
+ "relu": F.relu,
44
+ "swish": swish,
45
+ "gelu": gelu,
46
+ "tanh": torch.tanh,
47
+ "gelu_new": gelu_new,
48
+ "gelu_fast": gelu_fast,
49
+ }
50
+
51
+
52
+ def get_activation(activation_string):
53
+ if activation_string in ACT2FN:
54
+ return ACT2FN[activation_string]
55
+ else:
56
+ raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
LAVT-RIS/bert/configuration_bert.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ BERT model configuration """
17
+
18
+
19
+ import logging
20
+
21
+ from .configuration_utils import PretrainedConfig
22
+
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
28
+ "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
29
+ "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
30
+ "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
31
+ "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
32
+ "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
33
+ "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
34
+ "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
35
+ "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
36
+ "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
37
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
38
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
39
+ "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
40
+ "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
41
+ "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
42
+ "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
43
+ "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
44
+ "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
45
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
46
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
47
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
48
+ "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
49
+ # See all BERT models at https://huggingface.co/models?filter=bert
50
+ }
51
+
52
+
53
+ class BertConfig(PretrainedConfig):
54
+ r"""
55
+ This is the configuration class to store the configuration of a :class:`~transformers.BertModel`.
56
+ It is used to instantiate an BERT model according to the specified arguments, defining the model
57
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
58
+ the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.
59
+
60
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
61
+ to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
62
+ for more information.
63
+
64
+
65
+ Args:
66
+ vocab_size (:obj:`int`, optional, defaults to 30522):
67
+ Vocabulary size of the BERT model. Defines the different tokens that
68
+ can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
69
+ hidden_size (:obj:`int`, optional, defaults to 768):
70
+ Dimensionality of the encoder layers and the pooler layer.
71
+ num_hidden_layers (:obj:`int`, optional, defaults to 12):
72
+ Number of hidden layers in the Transformer encoder.
73
+ num_attention_heads (:obj:`int`, optional, defaults to 12):
74
+ Number of attention heads for each attention layer in the Transformer encoder.
75
+ intermediate_size (:obj:`int`, optional, defaults to 3072):
76
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
77
+ hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
78
+ The non-linear activation function (function or string) in the encoder and pooler.
79
+ If string, "gelu", "relu", "swish" and "gelu_new" are supported.
80
+ hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
81
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
82
+ attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
83
+ The dropout ratio for the attention probabilities.
84
+ max_position_embeddings (:obj:`int`, optional, defaults to 512):
85
+ The maximum sequence length that this model might ever be used with.
86
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
87
+ type_vocab_size (:obj:`int`, optional, defaults to 2):
88
+ The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
89
+ initializer_range (:obj:`float`, optional, defaults to 0.02):
90
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
91
+ layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
92
+ The epsilon used by the layer normalization layers.
93
+ gradient_checkpointing (:obj:`bool`, optional, defaults to False):
94
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
95
+
96
+ Example::
97
+
98
+ >>> from transformers import BertModel, BertConfig
99
+
100
+ >>> # Initializing a BERT bert-base-uncased style configuration
101
+ >>> configuration = BertConfig()
102
+
103
+ >>> # Initializing a model from the bert-base-uncased style configuration
104
+ >>> model = BertModel(configuration)
105
+
106
+ >>> # Accessing the model configuration
107
+ >>> configuration = model.config
108
+ """
109
+ model_type = "bert"
110
+
111
+ def __init__(
112
+ self,
113
+ vocab_size=30522,
114
+ hidden_size=768,
115
+ num_hidden_layers=12,
116
+ num_attention_heads=12,
117
+ intermediate_size=3072,
118
+ hidden_act="gelu",
119
+ hidden_dropout_prob=0.1,
120
+ attention_probs_dropout_prob=0.1,
121
+ max_position_embeddings=512,
122
+ type_vocab_size=2,
123
+ initializer_range=0.02,
124
+ layer_norm_eps=1e-12,
125
+ pad_token_id=0,
126
+ gradient_checkpointing=False,
127
+ **kwargs
128
+ ):
129
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
130
+
131
+ self.vocab_size = vocab_size
132
+ self.hidden_size = hidden_size
133
+ self.num_hidden_layers = num_hidden_layers
134
+ self.num_attention_heads = num_attention_heads
135
+ self.hidden_act = hidden_act
136
+ self.intermediate_size = intermediate_size
137
+ self.hidden_dropout_prob = hidden_dropout_prob
138
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
139
+ self.max_position_embeddings = max_position_embeddings
140
+ self.type_vocab_size = type_vocab_size
141
+ self.initializer_range = initializer_range
142
+ self.layer_norm_eps = layer_norm_eps
143
+ self.gradient_checkpointing = gradient_checkpointing
LAVT-RIS/bert/configuration_utils.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Configuration base class and utilities."""
17
+
18
+
19
+ import copy
20
+ import json
21
+ import logging
22
+ import os
23
+ from typing import Dict, Tuple
24
+
25
+ from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class PretrainedConfig(object):
32
+ r""" Base class for all configuration classes.
33
+ Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
34
+
35
+ Note:
36
+ A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
37
+ It only affects the model's configuration.
38
+
39
+ Class attributes (overridden by derived classes):
40
+ - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.
41
+
42
+ Args:
43
+ finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
44
+ Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
45
+ num_labels (:obj:`int`, `optional`, defaults to `2`):
46
+ Number of classes to use when the model is a classification model (sequences/tokens)
47
+ output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
48
+ Should the model returns all hidden-states.
49
+ output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
50
+ Should the model returns all attentions.
51
+ torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
52
+ Is the model used with Torchscript (for PyTorch models).
53
+ """
54
+ model_type: str = ""
55
+
56
+ def __init__(self, **kwargs):
57
+ # Attributes with defaults
58
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
59
+ self.output_attentions = kwargs.pop("output_attentions", False)
60
+ self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
61
+ self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
62
+ self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
63
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
64
+
65
+ # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
66
+ self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
67
+ self.is_decoder = kwargs.pop("is_decoder", False)
68
+
69
+ # Parameters for sequence generation
70
+ self.max_length = kwargs.pop("max_length", 20)
71
+ self.min_length = kwargs.pop("min_length", 0)
72
+ self.do_sample = kwargs.pop("do_sample", False)
73
+ self.early_stopping = kwargs.pop("early_stopping", False)
74
+ self.num_beams = kwargs.pop("num_beams", 1)
75
+ self.temperature = kwargs.pop("temperature", 1.0)
76
+ self.top_k = kwargs.pop("top_k", 50)
77
+ self.top_p = kwargs.pop("top_p", 1.0)
78
+ self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
79
+ self.length_penalty = kwargs.pop("length_penalty", 1.0)
80
+ self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
81
+ self.bad_words_ids = kwargs.pop("bad_words_ids", None)
82
+ self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
83
+
84
+ # Fine-tuning task arguments
85
+ self.architectures = kwargs.pop("architectures", None)
86
+ self.finetuning_task = kwargs.pop("finetuning_task", None)
87
+ self.id2label = kwargs.pop("id2label", None)
88
+ self.label2id = kwargs.pop("label2id", None)
89
+ if self.id2label is not None:
90
+ kwargs.pop("num_labels", None)
91
+ self.id2label = dict((int(key), value) for key, value in self.id2label.items())
92
+ # Keys are always strings in JSON so convert ids to int here.
93
+ else:
94
+ self.num_labels = kwargs.pop("num_labels", 2)
95
+
96
+ # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
97
+ self.prefix = kwargs.pop("prefix", None)
98
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
99
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
100
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
101
+ self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
102
+
103
+ # task specific arguments
104
+ self.task_specific_params = kwargs.pop("task_specific_params", None)
105
+
106
+ # TPU arguments
107
+ self.xla_device = kwargs.pop("xla_device", None)
108
+
109
+ # Additional attributes without default values
110
+ for key, value in kwargs.items():
111
+ try:
112
+ setattr(self, key, value)
113
+ except AttributeError as err:
114
+ logger.error("Can't set {} with value {} for {}".format(key, value, self))
115
+ raise err
116
+
117
+ @property
118
+ def num_labels(self):
119
+ return len(self.id2label)
120
+
121
+ @num_labels.setter
122
+ def num_labels(self, num_labels):
123
+ self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
124
+ self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
125
+
126
+ def save_pretrained(self, save_directory):
127
+ """
128
+ Save a configuration object to the directory `save_directory`, so that it
129
+ can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
130
+
131
+ Args:
132
+ save_directory (:obj:`string`):
133
+ Directory where the configuration JSON file will be saved.
134
+ """
135
+ if os.path.isfile(save_directory):
136
+ raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
137
+ os.makedirs(save_directory, exist_ok=True)
138
+ # If we save using the predefined names, we can load using `from_pretrained`
139
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
140
+
141
+ self.to_json_file(output_config_file, use_diff=True)
142
+ logger.info("Configuration saved in {}".format(output_config_file))
143
+
144
+ @classmethod
145
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
146
+ r"""
147
+
148
+ Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
149
+
150
+ Args:
151
+ pretrained_model_name_or_path (:obj:`string`):
152
+ either:
153
+ - a string with the `shortcut name` of a pre-trained model configuration to load from cache or
154
+ download, e.g.: ``bert-base-uncased``.
155
+ - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
156
+ our S3, e.g.: ``dbmdz/bert-base-german-cased``.
157
+ - a path to a `directory` containing a configuration file saved using the
158
+ :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
159
+ - a path or url to a saved configuration JSON `file`, e.g.:
160
+ ``./my_model_directory/configuration.json``.
161
+ cache_dir (:obj:`string`, `optional`):
162
+ Path to a directory in which a downloaded pre-trained model
163
+ configuration should be cached if the standard cache should not be used.
164
+ kwargs (:obj:`Dict[str, any]`, `optional`):
165
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
166
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
167
+ controlled by the `return_unused_kwargs` keyword parameter.
168
+ force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
169
+ Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
170
+ resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
171
+ Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
172
+ proxies (:obj:`Dict`, `optional`):
173
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.:
174
+ :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.`
175
+ The proxies are used on each request.
176
+ return_unused_kwargs: (`optional`) bool:
177
+ If False, then this function returns just the final configuration object.
178
+ If True, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
179
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part
180
+ of kwargs which has not been used to update `config` and is otherwise ignored.
181
+
182
+ Returns:
183
+ :class:`PretrainedConfig`: An instance of a configuration object
184
+
185
+ Examples::
186
+
187
+ # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
188
+ # derived class: BertConfig
189
+ config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
190
+ config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
191
+ config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
192
+ config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
193
+ assert config.output_attention == True
194
+ config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
195
+ foo=False, return_unused_kwargs=True)
196
+ assert config.output_attention == True
197
+ assert unused_kwargs == {'foo': False}
198
+
199
+ """
200
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
201
+ return cls.from_dict(config_dict, **kwargs)
202
+
203
+ @classmethod
204
+ def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]:
205
+ """
206
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
207
+ for instantiating a Config using `from_dict`.
208
+
209
+ Parameters:
210
+ pretrained_model_name_or_path (:obj:`string`):
211
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
212
+
213
+ Returns:
214
+ :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object.
215
+
216
+ """
217
+ cache_dir = kwargs.pop("cache_dir", None)
218
+ force_download = kwargs.pop("force_download", False)
219
+ resume_download = kwargs.pop("resume_download", False)
220
+ proxies = kwargs.pop("proxies", None)
221
+ local_files_only = kwargs.pop("local_files_only", False)
222
+
223
+ if os.path.isdir(pretrained_model_name_or_path):
224
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
225
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
226
+ config_file = pretrained_model_name_or_path
227
+ else:
228
+ config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
229
+
230
+ try:
231
+ # Load from URL or cache if already cached
232
+ resolved_config_file = cached_path(
233
+ config_file,
234
+ cache_dir=cache_dir,
235
+ force_download=force_download,
236
+ proxies=proxies,
237
+ resume_download=resume_download,
238
+ local_files_only=local_files_only,
239
+ )
240
+ # Load config dict
241
+ if resolved_config_file is None:
242
+ raise EnvironmentError
243
+ config_dict = cls._dict_from_json_file(resolved_config_file)
244
+
245
+ except EnvironmentError:
246
+ msg = (
247
+ f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
248
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
249
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
250
+ )
251
+ raise EnvironmentError(msg)
252
+
253
+ except json.JSONDecodeError:
254
+ msg = (
255
+ "Couldn't reach server at '{}' to download configuration file or "
256
+ "configuration file is not a valid JSON file. "
257
+ "Please check network or file content here: {}.".format(config_file, resolved_config_file)
258
+ )
259
+ raise EnvironmentError(msg)
260
+
261
+ if resolved_config_file == config_file:
262
+ logger.info("loading configuration file {}".format(config_file))
263
+ else:
264
+ logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
265
+
266
+ return config_dict, kwargs
267
+
268
+ @classmethod
269
+ def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
270
+ """
271
+ Constructs a `Config` from a Python dictionary of parameters.
272
+
273
+ Args:
274
+ config_dict (:obj:`Dict[str, any]`):
275
+ Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
276
+ from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
277
+ method.
278
+ kwargs (:obj:`Dict[str, any]`):
279
+ Additional parameters from which to initialize the configuration object.
280
+
281
+ Returns:
282
+ :class:`PretrainedConfig`: An instance of a configuration object
283
+ """
284
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
285
+
286
+ config = cls(**config_dict)
287
+
288
+ if hasattr(config, "pruned_heads"):
289
+ config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
290
+
291
+ # Update config with kwargs if needed
292
+ to_remove = []
293
+ for key, value in kwargs.items():
294
+ if hasattr(config, key):
295
+ setattr(config, key, value)
296
+ to_remove.append(key)
297
+ for key in to_remove:
298
+ kwargs.pop(key, None)
299
+
300
+ logger.info("Model config %s", str(config))
301
+ if return_unused_kwargs:
302
+ return config, kwargs
303
+ else:
304
+ return config
305
+
306
+ @classmethod
307
+ def from_json_file(cls, json_file: str) -> "PretrainedConfig":
308
+ """
309
+ Constructs a `Config` from the path to a json file of parameters.
310
+
311
+ Args:
312
+ json_file (:obj:`string`):
313
+ Path to the JSON file containing the parameters.
314
+
315
+ Returns:
316
+ :class:`PretrainedConfig`: An instance of a configuration object
317
+
318
+ """
319
+ config_dict = cls._dict_from_json_file(json_file)
320
+ return cls(**config_dict)
321
+
322
+ @classmethod
323
+ def _dict_from_json_file(cls, json_file: str):
324
+ with open(json_file, "r", encoding="utf-8") as reader:
325
+ text = reader.read()
326
+ return json.loads(text)
327
+
328
+ def __eq__(self, other):
329
+ return self.__dict__ == other.__dict__
330
+
331
+ def __repr__(self):
332
+ return "{} {}".format(self.__class__.__name__, self.to_json_string())
333
+
334
+ def to_diff_dict(self):
335
+ """
336
+ Removes all attributes from config which correspond to the default
337
+ config attributes for better readability and serializes to a Python
338
+ dictionary.
339
+
340
+ Returns:
341
+ :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
342
+ """
343
+ config_dict = self.to_dict()
344
+
345
+ # get the default config dict
346
+ default_config_dict = PretrainedConfig().to_dict()
347
+
348
+ serializable_config_dict = {}
349
+
350
+ # only serialize values that differ from the default config
351
+ for key, value in config_dict.items():
352
+ if key not in default_config_dict or value != default_config_dict[key]:
353
+ serializable_config_dict[key] = value
354
+
355
+ return serializable_config_dict
356
+
357
+ def to_dict(self):
358
+ """
359
+ Serializes this instance to a Python dictionary.
360
+
361
+ Returns:
362
+ :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
363
+ """
364
+ output = copy.deepcopy(self.__dict__)
365
+ if hasattr(self.__class__, "model_type"):
366
+ output["model_type"] = self.__class__.model_type
367
+ return output
368
+
369
+ def to_json_string(self, use_diff=True):
370
+ """
371
+ Serializes this instance to a JSON string.
372
+
373
+ Args:
374
+ use_diff (:obj:`bool`):
375
+ If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string.
376
+
377
+ Returns:
378
+ :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
379
+ """
380
+ if use_diff is True:
381
+ config_dict = self.to_diff_dict()
382
+ else:
383
+ config_dict = self.to_dict()
384
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
385
+
386
+ def to_json_file(self, json_file_path, use_diff=True):
387
+ """
388
+ Save this instance to a json file.
389
+
390
+ Args:
391
+ json_file_path (:obj:`string`):
392
+ Path to the JSON file in which this configuration instance's parameters will be saved.
393
+ use_diff (:obj:`bool`):
394
+ If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
395
+ """
396
+ with open(json_file_path, "w", encoding="utf-8") as writer:
397
+ writer.write(self.to_json_string(use_diff=use_diff))
398
+
399
+ def update(self, config_dict: Dict):
400
+ """
401
+ Updates attributes of this class
402
+ with attributes from `config_dict`.
403
+
404
+ Args:
405
+ :obj:`Dict[str, any]`: Dictionary of attributes that shall be updated for this class.
406
+ """
407
+ for key, value in config_dict.items():
408
+ setattr(self, key, value)
LAVT-RIS/bert/file_utils.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for working with the local dataset cache.
3
+ This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4
+ Copyright by the AllenNLP authors.
5
+ """
6
+
7
+ import fnmatch
8
+ import json
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import sys
13
+ import tarfile
14
+ import tempfile
15
+ from contextlib import contextmanager
16
+ from functools import partial, wraps
17
+ from hashlib import sha256
18
+ from pathlib import Path
19
+ from typing import Dict, Optional, Union
20
+ from urllib.parse import urlparse
21
+ from zipfile import ZipFile, is_zipfile
22
+
23
+ import requests
24
+ from filelock import FileLock
25
+ from tqdm.auto import tqdm
26
+
27
+ #from . import __version__
28
+ __version__ = "3.0.2"
29
+
30
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
31
+
32
+ try:
33
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
34
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
35
+ if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
36
+ import torch
37
+
38
+ _torch_available = True # pylint: disable=invalid-name
39
+ logger.info("PyTorch version {} available.".format(torch.__version__))
40
+ else:
41
+ logger.info("Disabling PyTorch because USE_TF is set")
42
+ _torch_available = False
43
+ except ImportError:
44
+ _torch_available = False # pylint: disable=invalid-name
45
+
46
+ try:
47
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
48
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
49
+
50
+ if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
51
+ import tensorflow as tf
52
+
53
+ assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
54
+ _tf_available = True # pylint: disable=invalid-name
55
+ logger.info("TensorFlow version {} available.".format(tf.__version__))
56
+ else:
57
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
58
+ _tf_available = False
59
+ except (ImportError, AssertionError):
60
+ _tf_available = False # pylint: disable=invalid-name
61
+
62
+
63
+ try:
64
+ from torch.hub import _get_torch_home
65
+
66
+ torch_cache_home = _get_torch_home()
67
+ except ImportError:
68
+ torch_cache_home = os.path.expanduser(
69
+ os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
70
+ )
71
+
72
+
73
+ try:
74
+ import torch_xla.core.xla_model as xm # noqa: F401
75
+
76
+ if _torch_available:
77
+ _torch_tpu_available = True # pylint: disable=
78
+ else:
79
+ _torch_tpu_available = False
80
+ except ImportError:
81
+ _torch_tpu_available = False
82
+
83
+
84
+ try:
85
+ import psutil # noqa: F401
86
+
87
+ _psutil_available = True
88
+
89
+ except ImportError:
90
+ _psutil_available = False
91
+
92
+
93
+ try:
94
+ import py3nvml # noqa: F401
95
+
96
+ _py3nvml_available = True
97
+
98
+ except ImportError:
99
+ _py3nvml_available = False
100
+
101
+
102
+ try:
103
+ from apex import amp # noqa: F401
104
+
105
+ _has_apex = True
106
+ except ImportError:
107
+ _has_apex = False
108
+
109
+ default_cache_path = os.path.join(torch_cache_home, "transformers")
110
+
111
+
112
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
113
+ PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
114
+ TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
115
+
116
+ WEIGHTS_NAME = "pytorch_model.bin"
117
+ TF2_WEIGHTS_NAME = "tf_model.h5"
118
+ TF_WEIGHTS_NAME = "model.ckpt"
119
+ CONFIG_NAME = "config.json"
120
+ MODEL_CARD_NAME = "modelcard.json"
121
+
122
+
123
+ MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
124
+ DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
125
+ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
126
+
127
+ S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
128
+ CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
129
+
130
+
131
+ def is_torch_available():
132
+ return _torch_available
133
+
134
+
135
+ def is_tf_available():
136
+ return _tf_available
137
+
138
+
139
+ def is_torch_tpu_available():
140
+ return _torch_tpu_available
141
+
142
+
143
+ def is_psutil_available():
144
+ return _psutil_available
145
+
146
+
147
+ def is_py3nvml_available():
148
+ return _py3nvml_available
149
+
150
+
151
+ def is_apex_available():
152
+ return _has_apex
153
+
154
+
155
+ def add_start_docstrings(*docstr):
156
+ def docstring_decorator(fn):
157
+ fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
158
+ return fn
159
+
160
+ return docstring_decorator
161
+
162
+
163
+ def add_start_docstrings_to_callable(*docstr):
164
+ def docstring_decorator(fn):
165
+ class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
166
+ intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
167
+ note = r"""
168
+
169
+ .. note::
170
+ Although the recipe for forward pass needs to be defined within
171
+ this function, one should call the :class:`Module` instance afterwards
172
+ instead of this since the former takes care of running the
173
+ pre and post processing steps while the latter silently ignores them.
174
+ """
175
+ fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
176
+ return fn
177
+
178
+ return docstring_decorator
179
+
180
+
181
+ def add_end_docstrings(*docstr):
182
+ def docstring_decorator(fn):
183
+ fn.__doc__ = fn.__doc__ + "".join(docstr)
184
+ return fn
185
+
186
+ return docstring_decorator
187
+
188
+
189
+ PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
190
+ Example::
191
+
192
+ >>> from transformers import {tokenizer_class}, {model_class}
193
+ >>> import torch
194
+
195
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
196
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
197
+
198
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
199
+ >>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1
200
+
201
+ >>> outputs = model(**inputs, labels=labels)
202
+ >>> loss, scores = outputs[:2]
203
+ """
204
+
205
+ PT_QUESTION_ANSWERING_SAMPLE = r"""
206
+ Example::
207
+
208
+ >>> from transformers import {tokenizer_class}, {model_class}
209
+ >>> import torch
210
+
211
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
212
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
213
+
214
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
215
+ >>> start_positions = torch.tensor([1])
216
+ >>> end_positions = torch.tensor([3])
217
+
218
+ >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
219
+ >>> loss, start_scores, end_scores = outputs[:3]
220
+ """
221
+
222
+ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
223
+ Example::
224
+
225
+ >>> from transformers import {tokenizer_class}, {model_class}
226
+ >>> import torch
227
+
228
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
229
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
230
+
231
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
232
+ >>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
233
+ >>> outputs = model(**inputs, labels=labels)
234
+ >>> loss, logits = outputs[:2]
235
+ """
236
+
237
+ PT_MASKED_LM_SAMPLE = r"""
238
+ Example::
239
+
240
+ >>> from transformers import {tokenizer_class}, {model_class}
241
+ >>> import torch
242
+
243
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
244
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
245
+
246
+ >>> input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
247
+
248
+ >>> outputs = model(input_ids, labels=input_ids)
249
+ >>> loss, prediction_scores = outputs[:2]
250
+ """
251
+
252
+ PT_BASE_MODEL_SAMPLE = r"""
253
+ Example::
254
+
255
+ >>> from transformers import {tokenizer_class}, {model_class}
256
+ >>> import torch
257
+
258
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
259
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
260
+
261
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
262
+ >>> outputs = model(**inputs)
263
+
264
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
265
+ """
266
+
267
+ PT_MULTIPLE_CHOICE_SAMPLE = r"""
268
+ Example::
269
+
270
+ >>> from transformers import {tokenizer_class}, {model_class}
271
+ >>> import torch
272
+
273
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
274
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
275
+
276
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
277
+ >>> choice0 = "It is eaten with a fork and a knife."
278
+ >>> choice1 = "It is eaten while held in the hand."
279
+ >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
280
+
281
+ >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)
282
+ >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels) # batch size is 1
283
+
284
+ >>> # the linear classifier still needs to be trained
285
+ >>> loss, logits = outputs[:2]
286
+ """
287
+
288
+ PT_CAUSAL_LM_SAMPLE = r"""
289
+ Example::
290
+
291
+ >>> import torch
292
+ >>> from transformers import {tokenizer_class}, {model_class}
293
+
294
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
295
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
296
+
297
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
298
+ >>> outputs = model(**inputs, labels=inputs["input_ids"])
299
+ >>> loss, logits = outputs[:2]
300
+ """
301
+
302
+ TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
303
+ Example::
304
+
305
+ >>> from transformers import {tokenizer_class}, {model_class}
306
+ >>> import tensorflow as tf
307
+
308
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
309
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
310
+
311
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
312
+ >>> input_ids = inputs["input_ids"]
313
+ >>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
314
+
315
+ >>> outputs = model(inputs)
316
+ >>> loss, scores = outputs[:2]
317
+ """
318
+
319
+ TF_QUESTION_ANSWERING_SAMPLE = r"""
320
+ Example::
321
+
322
+ >>> from transformers import {tokenizer_class}, {model_class}
323
+ >>> import tensorflow as tf
324
+
325
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
326
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
327
+
328
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
329
+ >>> input_dict = tokenizer(question, text, return_tensors='tf')
330
+ >>> start_scores, end_scores = model(input_dict)
331
+
332
+ >>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
333
+ >>> answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
334
+ """
335
+
336
+ TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
337
+ Example::
338
+
339
+ >>> from transformers import {tokenizer_class}, {model_class}
340
+ >>> import tensorflow as tf
341
+
342
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
343
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
344
+
345
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
346
+ >>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
347
+
348
+ >>> outputs = model(inputs)
349
+ >>> loss, logits = outputs[:2]
350
+ """
351
+
352
+ TF_MASKED_LM_SAMPLE = r"""
353
+ Example::
354
+ >>> from transformers import {tokenizer_class}, {model_class}
355
+ >>> import tensorflow as tf
356
+
357
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
358
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
359
+
360
+ >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
361
+
362
+ >>> outputs = model(input_ids)
363
+ >>> prediction_scores = outputs[0]
364
+ """
365
+
366
+ TF_BASE_MODEL_SAMPLE = r"""
367
+ Example::
368
+
369
+ >>> from transformers import {tokenizer_class}, {model_class}
370
+ >>> import tensorflow as tf
371
+
372
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
373
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
374
+
375
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
376
+ >>> outputs = model(inputs)
377
+
378
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
379
+ """
380
+
381
+ TF_MULTIPLE_CHOICE_SAMPLE = r"""
382
+ Example::
383
+
384
+ >>> from transformers import {tokenizer_class}, {model_class}
385
+ >>> import tensorflow as tf
386
+
387
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
388
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
389
+
390
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
391
+ >>> choice0 = "It is eaten with a fork and a knife."
392
+ >>> choice1 = "It is eaten while held in the hand."
393
+
394
+ >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True)
395
+ >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
396
+ >>> outputs = model(inputs) # batch size is 1
397
+
398
+ >>> # the linear classifier still needs to be trained
399
+ >>> logits = outputs[0]
400
+ """
401
+
402
+ TF_CAUSAL_LM_SAMPLE = r"""
403
+ Example::
404
+
405
+ >>> from transformers import {tokenizer_class}, {model_class}
406
+ >>> import tensorflow as tf
407
+
408
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
409
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
410
+
411
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
412
+ >>> outputs = model(inputs)
413
+ >>> logits = outputs[0]
414
+ """
415
+
416
+
417
+ def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
418
+ def docstring_decorator(fn):
419
+ model_class = fn.__qualname__.split(".")[0]
420
+ is_tf_class = model_class[:2] == "TF"
421
+
422
+ if "SequenceClassification" in model_class:
423
+ code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE
424
+ elif "QuestionAnswering" in model_class:
425
+ code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE
426
+ elif "TokenClassification" in model_class:
427
+ code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE
428
+ elif "MultipleChoice" in model_class:
429
+ code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE
430
+ elif "MaskedLM" in model_class:
431
+ code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
432
+ elif "LMHead" in model_class:
433
+ code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
434
+ elif "Model" in model_class:
435
+ code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
436
+ else:
437
+ raise ValueError(f"Docstring can't be built for model {model_class}")
438
+
439
+ built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
440
+ fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc
441
+ return fn
442
+
443
+ return docstring_decorator
444
+
445
+
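+ # Illustrative sketch (editor's addition, not part of the original file): how the decorator
+ # above is typically applied. The class and checkpoint names below are placeholders.
+ #
+ #     @add_code_sample_docstrings(tokenizer_class="BertTokenizer", checkpoint="bert-base-uncased")
+ #     def forward(self, input_ids=None, ...):      # defined inside e.g. BertForMaskedLM
+ #         ...
+ #
+ # Since the qualified name starts with "Bert" (not "TF") and contains "MaskedLM",
+ # PT_MASKED_LM_SAMPLE is selected, formatted with the given tokenizer class and checkpoint,
+ # and appended to forward.__doc__.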
446
+ def is_remote_url(url_or_filename):
447
+ parsed = urlparse(url_or_filename)
448
+ return parsed.scheme in ("http", "https")
449
+
450
+
451
+ def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
452
+ """
453
+ Resolve a model identifier, and a file name, to a HF-hosted url
454
+ on either S3 or Cloudfront (a Content Delivery Network, or CDN).
455
+
456
+ Cloudfront is replicated over the globe so downloads are way faster
457
+ for the end user (and it also lowers our bandwidth costs). However, it
458
+ is more aggressively cached by default, so may not always reflect the
459
+ latest changes to the underlying file (default TTL is 24 hours).
460
+
461
+ In terms of client-side caching from this library, even though
462
+ Cloudfront relays the ETags from S3, using one or the other
463
+ (or switching from one to the other) will affect caching: cached files
464
+ are not shared between the two because the cached file's name contains
465
+ a hash of the url.
466
+ """
467
+ endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
468
+ legacy_format = "/" not in model_id
469
+ if legacy_format:
470
+ return f"{endpoint}/{model_id}-{filename}"
471
+ else:
472
+ return f"{endpoint}/{model_id}/{filename}"
473
+
474
+
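+ # Illustrative sketch (editor's addition): with the endpoint prefixes referenced below,
+ # a legacy identifier (no "/") and a namespaced one resolve to different layouts.
+ # The model ids here are placeholders.
+ #
+ #     hf_bucket_url("bert-base-uncased", "config.json")
+ #     # -> f"{CLOUDFRONT_DISTRIB_PREFIX}/bert-base-uncased-config.json"
+ #     hf_bucket_url("some-org/some-model", "config.json", use_cdn=False)
+ #     # -> f"{S3_BUCKET_PREFIX}/some-org/some-model/config.json"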
475
+ def url_to_filename(url, etag=None):
476
+ """
477
+ Convert `url` into a hashed filename in a repeatable way.
478
+ If `etag` is specified, append its hash to the url's, delimited
479
+ by a period.
480
+ If the url ends with .h5 (Keras HDF5 weights), '.h5' is appended to the name
481
+ so that TF 2.0 can identify it as an HDF5 file
482
+ (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
483
+ """
484
+ url_bytes = url.encode("utf-8")
485
+ url_hash = sha256(url_bytes)
486
+ filename = url_hash.hexdigest()
487
+
488
+ if etag:
489
+ etag_bytes = etag.encode("utf-8")
490
+ etag_hash = sha256(etag_bytes)
491
+ filename += "." + etag_hash.hexdigest()
492
+
493
+ if url.endswith(".h5"):
494
+ filename += ".h5"
495
+
496
+ return filename
497
+
498
+
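+ # Illustrative sketch (editor's addition), assuming only the standard library; the URL and
+ # etag values are hypothetical.
+ #
+ #     from hashlib import sha256
+ #     url, etag = "https://example.com/model.h5", '"abc123"'
+ #     name = url_to_filename(url, etag)
+ #     # name == sha256(url.encode()).hexdigest() + "." + sha256(etag.encode()).hexdigest() + ".h5"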
499
+ def filename_to_url(filename, cache_dir=None):
500
+ """
501
+ Return the url and etag (which may be ``None``) stored for `filename`.
502
+ Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
503
+ """
504
+ if cache_dir is None:
505
+ cache_dir = TRANSFORMERS_CACHE
506
+ if isinstance(cache_dir, Path):
507
+ cache_dir = str(cache_dir)
508
+
509
+ cache_path = os.path.join(cache_dir, filename)
510
+ if not os.path.exists(cache_path):
511
+ raise EnvironmentError("file {} not found".format(cache_path))
512
+
513
+ meta_path = cache_path + ".json"
514
+ if not os.path.exists(meta_path):
515
+ raise EnvironmentError("file {} not found".format(meta_path))
516
+
517
+ with open(meta_path, encoding="utf-8") as meta_file:
518
+ metadata = json.load(meta_file)
519
+ url = metadata["url"]
520
+ etag = metadata["etag"]
521
+
522
+ return url, etag
523
+
524
+
525
+ def cached_path(
526
+ url_or_filename,
527
+ cache_dir=None,
528
+ force_download=False,
529
+ proxies=None,
530
+ resume_download=False,
531
+ user_agent: Union[Dict, str, None] = None,
532
+ extract_compressed_file=False,
533
+ force_extract=False,
534
+ local_files_only=False,
535
+ ) -> Optional[str]:
536
+ """
537
+ Given something that might be a URL (or might be a local path),
538
+ determine which. If it's a URL, download the file and cache it, and
539
+ return the path to the cached file. If it's already a local path,
540
+ make sure the file exists and then return the path.
541
+ Args:
542
+ cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
543
+ force_download: if True, re-download the file even if it's already cached in the cache dir.
544
+ resume_download: if True, resume the download if an incompletely received file is found.
545
+ user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
546
+ extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
547
+ file into a folder alongside the archive.
548
+ force_extract: if True when extract_compressed_file is True and the archive was already extracted,
549
+ re-extract the archive and override the folder where it was extracted.
550
+
551
+ Return:
552
+ None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
553
+ Local path (string) otherwise
554
+ """
555
+ if cache_dir is None:
556
+ cache_dir = TRANSFORMERS_CACHE
557
+ if isinstance(url_or_filename, Path):
558
+ url_or_filename = str(url_or_filename)
559
+ if isinstance(cache_dir, Path):
560
+ cache_dir = str(cache_dir)
561
+
562
+ if is_remote_url(url_or_filename):
563
+ # URL, so get it from the cache (downloading if necessary)
564
+ output_path = get_from_cache(
565
+ url_or_filename,
566
+ cache_dir=cache_dir,
567
+ force_download=force_download,
568
+ proxies=proxies,
569
+ resume_download=resume_download,
570
+ user_agent=user_agent,
571
+ local_files_only=local_files_only,
572
+ )
573
+ elif os.path.exists(url_or_filename):
574
+ # File, and it exists.
575
+ output_path = url_or_filename
576
+ elif urlparse(url_or_filename).scheme == "":
577
+ # File, but it doesn't exist.
578
+ raise EnvironmentError("file {} not found".format(url_or_filename))
579
+ else:
580
+ # Something unknown
581
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
582
+
583
+ if extract_compressed_file:
584
+ if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
585
+ return output_path
586
+
587
+ # Path where we extract compressed archives
588
+ # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
589
+ output_dir, output_file = os.path.split(output_path)
590
+ output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
591
+ output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
592
+
593
+ if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
594
+ return output_path_extracted
595
+
596
+ # Prevent parallel extractions
597
+ lock_path = output_path + ".lock"
598
+ with FileLock(lock_path):
599
+ shutil.rmtree(output_path_extracted, ignore_errors=True)
600
+ os.makedirs(output_path_extracted)
601
+ if is_zipfile(output_path):
602
+ with ZipFile(output_path, "r") as zip_file:
603
+ zip_file.extractall(output_path_extracted)
604
+ zip_file.close()
605
+ elif tarfile.is_tarfile(output_path):
606
+ tar_file = tarfile.open(output_path)
607
+ tar_file.extractall(output_path_extracted)
608
+ tar_file.close()
609
+ else:
610
+ raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
611
+
612
+ return output_path_extracted
613
+
614
+ return output_path
615
+
616
+
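+ # Illustrative sketch (editor's addition): typical uses of cached_path. The URLs below are
+ # placeholders, not files guaranteed to exist.
+ #
+ #     local = cached_path("/tmp/some_local_file.bin")            # returned unchanged if it exists
+ #     cached = cached_path("https://example.com/weights.bin")    # downloaded once, then served from cache
+ #     folder = cached_path("https://example.com/archive.zip", extract_compressed_file=True)
+ #     # -> ".../archive-zip-extracted/" created next to the cached archive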
617
+ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
618
+ ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
619
+ if is_torch_available():
620
+ ua += "; torch/{}".format(torch.__version__)
621
+ if is_tf_available():
622
+ ua += "; tensorflow/{}".format(tf.__version__)
623
+ if isinstance(user_agent, dict):
624
+ ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
625
+ elif isinstance(user_agent, str):
626
+ ua += "; " + user_agent
627
+ headers = {"user-agent": ua}
628
+ if resume_size > 0:
629
+ headers["Range"] = "bytes=%d-" % (resume_size,)
630
+ response = requests.get(url, stream=True, proxies=proxies, headers=headers)
631
+ if response.status_code == 416: # Range not satisfiable
632
+ return
633
+ content_length = response.headers.get("Content-Length")
634
+ total = resume_size + int(content_length) if content_length is not None else None
635
+ progress = tqdm(
636
+ unit="B",
637
+ unit_scale=True,
638
+ total=total,
639
+ initial=resume_size,
640
+ desc="Downloading",
641
+ disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
642
+ )
643
+ for chunk in response.iter_content(chunk_size=1024):
644
+ if chunk: # filter out keep-alive new chunks
645
+ progress.update(len(chunk))
646
+ temp_file.write(chunk)
647
+ progress.close()
648
+
649
+
650
+ def get_from_cache(
651
+ url,
652
+ cache_dir=None,
653
+ force_download=False,
654
+ proxies=None,
655
+ etag_timeout=10,
656
+ resume_download=False,
657
+ user_agent: Union[Dict, str, None] = None,
658
+ local_files_only=False,
659
+ ) -> Optional[str]:
660
+ """
661
+ Given a URL, look for the corresponding file in the local cache.
662
+ If it's not there, download it. Then return the path to the cached file.
663
+
664
+ Return:
665
+ None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
666
+ Local path (string) otherwise
667
+ """
668
+ if cache_dir is None:
669
+ cache_dir = TRANSFORMERS_CACHE
670
+ if isinstance(cache_dir, Path):
671
+ cache_dir = str(cache_dir)
672
+
673
+ os.makedirs(cache_dir, exist_ok=True)
674
+
675
+ etag = None
676
+ if not local_files_only:
677
+ try:
678
+ response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
679
+ if response.status_code == 200:
680
+ etag = response.headers.get("ETag")
681
+ except (EnvironmentError, requests.exceptions.Timeout):
682
+ # etag is already None
683
+ pass
684
+
685
+ filename = url_to_filename(url, etag)
686
+
687
+ # get cache path to put the file
688
+ cache_path = os.path.join(cache_dir, filename)
689
+
690
+ # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
691
+ # try to get the last downloaded one
692
+ if etag is None:
693
+ if os.path.exists(cache_path):
694
+ return cache_path
695
+ else:
696
+ matching_files = [
697
+ file
698
+ for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
699
+ if not file.endswith(".json") and not file.endswith(".lock")
700
+ ]
701
+ if len(matching_files) > 0:
702
+ return os.path.join(cache_dir, matching_files[-1])
703
+ else:
704
+ # If files cannot be found and local_files_only=True,
705
+ # the models might've been found if local_files_only=False
706
+ # Notify the user about that
707
+ if local_files_only:
708
+ raise ValueError(
709
+ "Cannot find the requested files in the cached path and outgoing traffic has been"
710
+ " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
711
+ " to False."
712
+ )
713
+ return None
714
+
715
+ # From now on, etag is not None.
716
+ if os.path.exists(cache_path) and not force_download:
717
+ return cache_path
718
+
719
+ # Prevent parallel downloads of the same file with a lock.
720
+ lock_path = cache_path + ".lock"
721
+ with FileLock(lock_path):
722
+
723
+ # If the download just completed while the lock was activated.
724
+ if os.path.exists(cache_path) and not force_download:
725
+ # Even if returning early like here, the lock will be released.
726
+ return cache_path
727
+
728
+ if resume_download:
729
+ incomplete_path = cache_path + ".incomplete"
730
+
731
+ @contextmanager
732
+ def _resumable_file_manager():
733
+ with open(incomplete_path, "a+b") as f:
734
+ yield f
735
+
736
+ temp_file_manager = _resumable_file_manager
737
+ if os.path.exists(incomplete_path):
738
+ resume_size = os.stat(incomplete_path).st_size
739
+ else:
740
+ resume_size = 0
741
+ else:
742
+ temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
743
+ resume_size = 0
744
+
745
+ # Download to temporary file, then copy to cache dir once finished.
746
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
747
+ with temp_file_manager() as temp_file:
748
+ logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
749
+
750
+ http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
751
+
752
+ logger.info("storing %s in cache at %s", url, cache_path)
753
+ os.replace(temp_file.name, cache_path)
754
+
755
+ logger.info("creating metadata file for %s", cache_path)
756
+ meta = {"url": url, "etag": etag}
757
+ meta_path = cache_path + ".json"
758
+ with open(meta_path, "w") as meta_file:
759
+ json.dump(meta, meta_file)
760
+
761
+ return cache_path
762
+
763
+
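+ # Illustrative sketch (editor's addition): a second process requesting the same URL blocks on
+ # "<cache_path>.lock" until the first download finishes, then returns the cached file because
+ # os.path.exists(cache_path) is re-checked inside the lock. With resume_download=True, partial
+ # downloads live in "<cache_path>.incomplete" and are resumed via the HTTP Range header, e.g.:
+ #
+ #     path = get_from_cache("https://example.com/weights.bin", resume_download=True)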
764
+ class cached_property(property):
765
+ """
766
+ Descriptor that mimics @property but caches output in member variable.
767
+
768
+ From tensorflow_datasets
769
+
770
+ Built-in in functools from Python 3.8.
771
+ """
772
+
773
+ def __get__(self, obj, objtype=None):
774
+ # See docs.python.org/3/howto/descriptor.html#properties
775
+ if obj is None:
776
+ return self
777
+ if self.fget is None:
778
+ raise AttributeError("unreadable attribute")
779
+ attr = "__cached_" + self.fget.__name__
780
+ cached = getattr(obj, attr, None)
781
+ if cached is None:
782
+ cached = self.fget(obj)
783
+ setattr(obj, attr, cached)
784
+ return cached
785
+
786
+
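+ # Illustrative sketch (editor's addition): cached_property evaluates the wrapped method once per
+ # instance and stores the result on the object. The class below is hypothetical.
+ #
+ #     class Example:
+ #         @cached_property
+ #         def value(self):
+ #             print("computed")
+ #             return 42
+ #
+ #     e = Example()
+ #     e.value   # prints "computed" and returns 42
+ #     e.value   # returns 42 from the "__cached_value" attribute without recomputation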
787
+ def torch_required(func):
788
+ # Chose a different decorator name than in tests so it's clear they are not the same.
789
+ @wraps(func)
790
+ def wrapper(*args, **kwargs):
791
+ if is_torch_available():
792
+ return func(*args, **kwargs)
793
+ else:
794
+ raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
795
+
796
+ return wrapper
797
+
798
+
799
+ def tf_required(func):
800
+ # Chose a different decorator name than in tests so it's clear they are not the same.
801
+ @wraps(func)
802
+ def wrapper(*args, **kwargs):
803
+ if is_tf_available():
804
+ return func(*args, **kwargs)
805
+ else:
806
+ raise ImportError(f"Method `{func.__name__}` requires TF.")
807
+
808
+ return wrapper
LAVT-RIS/bert/generation_utils.py ADDED
@@ -0,0 +1,993 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+ from typing import Iterable, Optional, Tuple
19
+
20
+ import torch
21
+ from torch import Tensor
22
+ from torch.nn import functional as F
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class GenerationMixin:
29
+ """
30
+ A class containing all of the functions supporting generation, to be used as a mixin in PreTrainedModel.
31
+ """
32
+
33
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
34
+ return {"input_ids": input_ids}
35
+
36
+ def adjust_logits_during_generation(self, logits, **kwargs):
37
+ return logits
38
+
39
+ def _use_cache(self, outputs, use_cache):
40
+ """During generation, decide whether to pass the `past` variable to the next forward pass."""
41
+ if len(outputs) <= 1 or use_cache is False:
42
+ return False
43
+ if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
44
+ return False
45
+ return True
46
+
47
+ def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
48
+ """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
49
+ for i in range(batch_size * num_beams):
50
+ for previous_token in set(prev_output_tokens[i].tolist()):
51
+ # if score < 0 then the repetition penalty has to be multiplied to reduce the previous token probability
52
+ if lprobs[i, previous_token] < 0:
53
+ lprobs[i, previous_token] *= repetition_penalty
54
+ else:
55
+ lprobs[i, previous_token] /= repetition_penalty
56
+
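+ # Illustrative numeric sketch (editor's addition): with repetition_penalty=1.2, a previously
+ # generated token with log-probability -2.0 becomes -2.0 * 1.2 = -2.4 (less likely), while a
+ # positive score of 3.0 would become 3.0 / 1.2 = 2.5; both cases discourage repetition.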
57
+ def postprocess_next_token_scores(
58
+ self,
59
+ scores,
60
+ input_ids,
61
+ no_repeat_ngram_size,
62
+ bad_words_ids,
63
+ cur_len,
64
+ min_length,
65
+ max_length,
66
+ eos_token_id,
67
+ repetition_penalty,
68
+ batch_size,
69
+ num_beams,
70
+ ):
71
+ # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
72
+ if repetition_penalty != 1.0:
73
+ self.enforce_repetition_penalty_(
74
+ scores, batch_size, num_beams, input_ids, repetition_penalty,
75
+ )
76
+
77
+ # set eos token prob to zero if min_length is not reached
78
+ if eos_token_id is not None and cur_len < min_length:
79
+ scores[:, eos_token_id] = -float("inf")
80
+
81
+ if no_repeat_ngram_size > 0:
82
+ # calculate a list of banned tokens to prevent repetitively generating the same ngrams
83
+ num_batch_hypotheses = batch_size * num_beams
84
+ # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
85
+ banned_batch_tokens = calc_banned_ngram_tokens(
86
+ input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
87
+ )
88
+ for i, banned_tokens in enumerate(banned_batch_tokens):
89
+ scores[i, banned_tokens] = -float("inf")
90
+
91
+ if bad_words_ids is not None:
92
+ # calculate a list of banned tokens according to bad words
93
+ banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
94
+
95
+ for i, banned_tokens in enumerate(banned_tokens):
96
+ scores[i, banned_tokens] = -float("inf")
97
+
98
+ return scores
99
+
100
+ @torch.no_grad()
101
+ def generate(
102
+ self,
103
+ input_ids: Optional[torch.LongTensor] = None,
104
+ max_length: Optional[int] = None,
105
+ min_length: Optional[int] = None,
106
+ do_sample: Optional[bool] = None,
107
+ early_stopping: Optional[bool] = None,
108
+ num_beams: Optional[int] = None,
109
+ temperature: Optional[float] = None,
110
+ top_k: Optional[int] = None,
111
+ top_p: Optional[float] = None,
112
+ repetition_penalty: Optional[float] = None,
113
+ bad_words_ids: Optional[Iterable[int]] = None,
114
+ bos_token_id: Optional[int] = None,
115
+ pad_token_id: Optional[int] = None,
116
+ eos_token_id: Optional[int] = None,
117
+ length_penalty: Optional[float] = None,
118
+ no_repeat_ngram_size: Optional[int] = None,
119
+ num_return_sequences: Optional[int] = None,
120
+ attention_mask: Optional[torch.LongTensor] = None,
121
+ decoder_start_token_id: Optional[int] = None,
122
+ use_cache: Optional[bool] = None,
123
+ **model_specific_kwargs
124
+ ) -> torch.LongTensor:
125
+ r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
126
+
127
+ Adapted in part from `Facebook's XLM beam search code`_.
128
+
129
+ .. _`Facebook's XLM beam search code`:
130
+ https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
131
+
132
+
133
+ Parameters:
134
+
135
+ input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
136
+ The sequence used as a prompt for the generation. If `None` the method initializes
137
+ it as an empty `torch.LongTensor` of shape `(1,)`.
138
+
139
+ max_length: (`optional`) int
140
+ The max length of the sequence to be generated. Between `min_length` and infinity. Default to 20.
141
+
142
+ min_length: (`optional`) int
143
+ The min length of the sequence to be generated. Between 0 and infinity. Default to 0.
144
+
145
+ do_sample: (`optional`) bool
146
+ If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
147
+
148
+ early_stopping: (`optional`) bool
149
+ if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
150
+
151
+ num_beams: (`optional`) int
152
+ Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
153
+
154
+ temperature: (`optional`) float
155
+ The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
156
+
157
+ top_k: (`optional`) int
158
+ The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
159
+
160
+ top_p: (`optional`) float
161
+ The cumulative probability threshold for nucleus sampling: only the smallest set of most probable vocabulary tokens whose cumulative probability reaches `top_p` is kept. Must be between 0 and 1. Default to 1.
162
+
163
+ repetition_penalty: (`optional`) float
164
+ The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
165
+
166
+ pad_token_id: (`optional`) int
167
+ Padding token. Defaults to the model-specific pad_token_id, or None if it does not exist.
168
+
169
+ bos_token_id: (`optional`) int
170
+ BOS token. Defaults to `bos_token_id` as defined in the models config.
171
+
172
+ eos_token_id: (`optional`) int
173
+ EOS token. Defaults to `eos_token_id` as defined in the models config.
174
+
175
+ length_penalty: (`optional`) float
176
+ Exponential penalty to the length. Default to 1.
177
+
178
+ no_repeat_ngram_size: (`optional`) int
179
+ If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
180
+ bad_words_ids: (`optional`) list of lists of int
181
+ `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
182
+
183
+ num_return_sequences: (`optional`) int
184
+ The number of independently computed returned sequences for each element in the batch. Default to 1.
185
+
186
+ attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids`
187
+ Mask to avoid performing attention on padding token indices.
188
+ Mask values selected in ``[0, 1]``:
189
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
190
+ Defaults to `None`.
191
+
192
+ `What are attention masks? <../glossary.html#attention-mask>`__
193
+
194
+ decoder_start_token_id: (`optional`) int
195
+ If an encoder-decoder model starts decoding with a different token than BOS.
196
+ Defaults to `None` and is changed to `BOS` later.
197
+
198
+ use_cache: (`optional`) bool
199
+ If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
200
+
201
+ model_specific_kwargs: (`optional`) dict
202
+ Additional model specific kwargs will be forwarded to the `forward` function of the model.
203
+
204
+ Return:
205
+
206
+ output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
207
+ sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
208
+
209
+ Examples::
210
+
211
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
212
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
213
+ outputs = model.generate(max_length=40) # do greedy decoding
214
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
215
+
216
+ tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
217
+ model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
218
+ input_context = 'The dog'
219
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
220
+ outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
221
+ for i in range(3): # 3 output sequences were generated
222
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
223
+
224
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
225
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
226
+ input_context = 'The dog'
227
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
228
+ outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3) # generate 3 sequences by sampling
229
+ for i in range(3): # 3 output sequences were generated
230
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
231
+
232
+ tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
233
+ model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
234
+ input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
235
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
236
+ outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
237
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
238
+
239
+ tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
240
+ model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
241
+ input_context = 'My cute dog'
242
+ bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
243
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
244
+ outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
245
+ """
246
+
247
+ # We cannot generate if the model does not have a LM head
248
+ if self.get_output_embeddings() is None:
249
+ raise AttributeError(
250
+ "You tried to generate sequences with a model that does not have a LM Head."
251
+ "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
252
+ )
253
+
254
+ max_length = max_length if max_length is not None else self.config.max_length
255
+ min_length = min_length if min_length is not None else self.config.min_length
256
+ do_sample = do_sample if do_sample is not None else self.config.do_sample
257
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
258
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
259
+ num_beams = num_beams if num_beams is not None else self.config.num_beams
260
+ temperature = temperature if temperature is not None else self.config.temperature
261
+ top_k = top_k if top_k is not None else self.config.top_k
262
+ top_p = top_p if top_p is not None else self.config.top_p
263
+ repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
264
+ bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
265
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
266
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
267
+ length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
268
+ no_repeat_ngram_size = (
269
+ no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
270
+ )
271
+ bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
272
+ num_return_sequences = (
273
+ num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
274
+ )
275
+ decoder_start_token_id = (
276
+ decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
277
+ )
278
+
279
+ if input_ids is not None:
280
+ batch_size = input_ids.shape[0] # overridden by the input batch_size
281
+ else:
282
+ batch_size = 1
283
+
284
+ assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
285
+ assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
286
+ assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
287
+ assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
288
+ assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
289
+ assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
290
+ assert temperature > 0, "`temperature` should be strictly positive."
291
+ assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
292
+ assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
293
+ assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
294
+ assert input_ids is not None or (
295
+ isinstance(bos_token_id, int) and bos_token_id >= 0
296
+ ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
297
+ assert pad_token_id is None or (
298
+ isinstance(pad_token_id, int) and (pad_token_id >= 0)
299
+ ), "`pad_token_id` should be a positive integer."
300
+ assert (eos_token_id is None) or (
301
+ isinstance(eos_token_id, int) and (eos_token_id >= 0)
302
+ ), "`eos_token_id` should be a positive integer."
303
+ assert length_penalty > 0, "`length_penalty` should be strictly positive."
304
+ assert (
305
+ isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
306
+ ), "`no_repeat_ngram_size` should be a positive integer."
307
+ assert (
308
+ isinstance(num_return_sequences, int) and num_return_sequences > 0
309
+ ), "`num_return_sequences` should be a strictly positive integer."
310
+ assert (
311
+ bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
312
+ ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
313
+
314
+ if input_ids is None:
315
+ assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
316
+ "you should either supply a context to complete as `input_ids` input "
317
+ "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
318
+ )
319
+ input_ids = torch.full(
320
+ (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
321
+ )
322
+ else:
323
+ assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
324
+
325
+ # do not allow duplicate outputs when greedy decoding
326
+ if do_sample is False:
327
+ if num_beams == 1:
328
+ # no_beam_search greedy generation conditions
329
+ assert (
330
+ num_return_sequences == 1
331
+ ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
332
+
333
+ else:
334
+ # beam_search greedy generation conditions
335
+ assert (
336
+ num_beams >= num_return_sequences
337
+ ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
338
+
339
+ # create attention mask if necessary
340
+ # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
341
+ if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
342
+ attention_mask = input_ids.ne(pad_token_id).long()
343
+ elif attention_mask is None:
344
+ attention_mask = input_ids.new_ones(input_ids.shape)
345
+
346
+ # set pad_token_id to eos_token_id if not set. Important that this is done after
347
+ # attention_mask is created
348
+ if pad_token_id is None and eos_token_id is not None:
349
+ logger.warning(
350
+ "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
351
+ )
352
+ pad_token_id = eos_token_id
353
+
354
+ # current position and vocab size
355
+ if hasattr(self.config, "vocab_size"):
356
+ vocab_size = self.config.vocab_size
357
+ elif (
358
+ self.config.is_encoder_decoder
359
+ and hasattr(self.config, "decoder")
360
+ and hasattr(self.config.decoder, "vocab_size")
361
+ ):
362
+ vocab_size = self.config.decoder.vocab_size
363
+
364
+ # set effective batch size and effective batch multiplier according to do_sample
365
+ if do_sample:
366
+ effective_batch_size = batch_size * num_return_sequences
367
+ effective_batch_mult = num_return_sequences
368
+ else:
369
+ effective_batch_size = batch_size
370
+ effective_batch_mult = 1
371
+
372
+ if self.config.is_encoder_decoder:
373
+ if decoder_start_token_id is None:
374
+ decoder_start_token_id = bos_token_id
375
+
376
+ assert (
377
+ decoder_start_token_id is not None
378
+ ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
379
+ assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
380
+ assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
381
+
382
+ # get encoder and store encoder outputs
383
+ encoder = self.get_encoder()
384
+
385
+ encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
386
+
387
+ # Expand input ids if num_beams > 1 or num_return_sequences > 1
388
+ if num_return_sequences > 1 or num_beams > 1:
389
+ input_ids_len = input_ids.shape[-1]
390
+ input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
391
+ attention_mask = attention_mask.unsqueeze(1).expand(
392
+ batch_size, effective_batch_mult * num_beams, input_ids_len
393
+ )
394
+
395
+ input_ids = input_ids.contiguous().view(
396
+ effective_batch_size * num_beams, input_ids_len
397
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
398
+ attention_mask = attention_mask.contiguous().view(
399
+ effective_batch_size * num_beams, input_ids_len
400
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
401
+
402
+ if self.config.is_encoder_decoder:
403
+ # create empty decoder_input_ids
404
+ input_ids = torch.full(
405
+ (effective_batch_size * num_beams, 1),
406
+ decoder_start_token_id,
407
+ dtype=torch.long,
408
+ device=next(self.parameters()).device,
409
+ )
410
+ cur_len = 1
411
+
412
+ assert (
413
+ batch_size == encoder_outputs[0].shape[0]
414
+ ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
415
+
416
+ # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
417
+ expanded_batch_idxs = (
418
+ torch.arange(batch_size)
419
+ .view(-1, 1)
420
+ .repeat(1, num_beams * effective_batch_mult)
421
+ .view(-1)
422
+ .to(input_ids.device)
423
+ )
424
+ # expand encoder_outputs
425
+ encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
426
+
427
+ else:
428
+ encoder_outputs = None
429
+ cur_len = input_ids.shape[-1]
430
+
431
+ assert (
432
+ cur_len < max_length
433
+ ), f"The context has {cur_len} tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
434
+
435
+ if num_beams > 1:
436
+ output = self._generate_beam_search(
437
+ input_ids,
438
+ cur_len=cur_len,
439
+ max_length=max_length,
440
+ min_length=min_length,
441
+ do_sample=do_sample,
442
+ early_stopping=early_stopping,
443
+ temperature=temperature,
444
+ top_k=top_k,
445
+ top_p=top_p,
446
+ repetition_penalty=repetition_penalty,
447
+ no_repeat_ngram_size=no_repeat_ngram_size,
448
+ bad_words_ids=bad_words_ids,
449
+ pad_token_id=pad_token_id,
450
+ eos_token_id=eos_token_id,
451
+ batch_size=effective_batch_size,
452
+ num_return_sequences=num_return_sequences,
453
+ length_penalty=length_penalty,
454
+ num_beams=num_beams,
455
+ vocab_size=vocab_size,
456
+ encoder_outputs=encoder_outputs,
457
+ attention_mask=attention_mask,
458
+ use_cache=use_cache,
459
+ model_specific_kwargs=model_specific_kwargs,
460
+ )
461
+ else:
462
+ output = self._generate_no_beam_search(
463
+ input_ids,
464
+ cur_len=cur_len,
465
+ max_length=max_length,
466
+ min_length=min_length,
467
+ do_sample=do_sample,
468
+ temperature=temperature,
469
+ top_k=top_k,
470
+ top_p=top_p,
471
+ repetition_penalty=repetition_penalty,
472
+ no_repeat_ngram_size=no_repeat_ngram_size,
473
+ bad_words_ids=bad_words_ids,
474
+ pad_token_id=pad_token_id,
475
+ eos_token_id=eos_token_id,
476
+ batch_size=effective_batch_size,
477
+ encoder_outputs=encoder_outputs,
478
+ attention_mask=attention_mask,
479
+ use_cache=use_cache,
480
+ model_specific_kwargs=model_specific_kwargs,
481
+ )
482
+
483
+ return output
484
+
485
+ def _generate_no_beam_search(
486
+ self,
487
+ input_ids,
488
+ cur_len,
489
+ max_length,
490
+ min_length,
491
+ do_sample,
492
+ temperature,
493
+ top_k,
494
+ top_p,
495
+ repetition_penalty,
496
+ no_repeat_ngram_size,
497
+ bad_words_ids,
498
+ pad_token_id,
499
+ eos_token_id,
500
+ batch_size,
501
+ encoder_outputs,
502
+ attention_mask,
503
+ use_cache,
504
+ model_specific_kwargs,
505
+ ):
506
+ """ Generate sequences for each example without beam search (num_beams == 1).
507
+ All returned sequences are generated independently.
508
+ """
509
+ # length of generated sentences / unfinished sentences
510
+ unfinished_sents = input_ids.new(batch_size).fill_(1)
511
+ sent_lengths = input_ids.new(batch_size).fill_(max_length)
512
+
513
+ past = (encoder_outputs, None) if encoder_outputs is not None else None
514
+
515
+ while cur_len < max_length:
516
+ model_inputs = self.prepare_inputs_for_generation(
517
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
518
+ )
519
+
520
+ outputs = self(**model_inputs)
521
+ next_token_logits = outputs[0][:, -1, :]
522
+
523
+ scores = self.postprocess_next_token_scores(
524
+ scores=next_token_logits,
525
+ input_ids=input_ids,
526
+ no_repeat_ngram_size=no_repeat_ngram_size,
527
+ bad_words_ids=bad_words_ids,
528
+ cur_len=cur_len,
529
+ min_length=min_length,
530
+ max_length=max_length,
531
+ eos_token_id=eos_token_id,
532
+ repetition_penalty=repetition_penalty,
533
+ batch_size=batch_size,
534
+ num_beams=1,
535
+ )
536
+
537
+ # if model has past, then set the past variable to speed up decoding
538
+ if self._use_cache(outputs, use_cache):
539
+ past = outputs[1]
540
+
541
+ if do_sample:
542
+ # Temperature (higher temperature => more likely to sample low probability tokens)
543
+ if temperature != 1.0:
544
+ scores = scores / temperature
545
+ # Top-p/top-k filtering
546
+ next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
547
+ # Sample
548
+ probs = F.softmax(next_token_logscores, dim=-1)
549
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
550
+ else:
551
+ # Greedy decoding
552
+ next_token = torch.argmax(next_token_logits, dim=-1)
553
+
554
+ # update generations and finished sentences
555
+ if eos_token_id is not None:
556
+ # pad finished sentences if eos_token_id exist
557
+ tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
558
+ else:
559
+ tokens_to_add = next_token
560
+
561
+ # add token and increase length by one
562
+ input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
563
+ cur_len = cur_len + 1
564
+
565
+ if eos_token_id is not None:
566
+ eos_in_sents = tokens_to_add == eos_token_id
567
+ # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
568
+ is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
569
+ sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
570
+ # unfinished_sents is set to zero if eos in sentence
571
+ unfinished_sents.mul_((~eos_in_sents).long())
572
+
573
+ # stop when there is a </s> in each sentence, or if we exceed the maximum length
574
+ if unfinished_sents.max() == 0:
575
+ break
576
+
577
+ # extend attention_mask for new generated input if only decoder
578
+ if self.config.is_encoder_decoder is False:
579
+ attention_mask = torch.cat(
580
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
581
+ )
582
+
583
+ return input_ids
584
+
585
+ def _generate_beam_search(
586
+ self,
587
+ input_ids,
588
+ cur_len,
589
+ max_length,
590
+ min_length,
591
+ do_sample,
592
+ early_stopping,
593
+ temperature,
594
+ top_k,
595
+ top_p,
596
+ repetition_penalty,
597
+ no_repeat_ngram_size,
598
+ bad_words_ids,
599
+ pad_token_id,
600
+ eos_token_id,
601
+ batch_size,
602
+ num_return_sequences,
603
+ length_penalty,
604
+ num_beams,
605
+ vocab_size,
606
+ encoder_outputs,
607
+ attention_mask,
608
+ use_cache,
609
+ model_specific_kwargs,
610
+ ):
611
+ """ Generate sequences for each example with beam search.
612
+ """
613
+
614
+ # generated hypotheses
615
+ generated_hyps = [
616
+ BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
617
+ for _ in range(batch_size)
618
+ ]
619
+
620
+ # scores for each sentence in the beam
621
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
622
+
623
+ # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
624
+ if do_sample is False:
625
+ beam_scores[:, 1:] = -1e9
626
+ beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
627
+
628
+ # cache compute states
629
+ past = (encoder_outputs, None) if encoder_outputs is not None else None
630
+
631
+ # done sentences
632
+ done = [False for _ in range(batch_size)]
633
+
634
+ while cur_len < max_length:
635
+ model_inputs = self.prepare_inputs_for_generation(
636
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
637
+ )
638
+ outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
639
+ next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
640
+
641
+ # if model has past, then set the past variable to speed up decoding
642
+ if self._use_cache(outputs, use_cache):
643
+ past = outputs[1]
644
+ if self.config.is_encoder_decoder and do_sample is False:
645
+ # TODO (PVP) still a bit hacky here - there might be a better solution
646
+ next_token_logits = self.adjust_logits_during_generation(
647
+ next_token_logits, cur_len=cur_len, max_length=max_length
648
+ )
649
+
650
+ scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
651
+
652
+ scores = self.postprocess_next_token_scores(
653
+ scores=scores,
654
+ input_ids=input_ids,
655
+ no_repeat_ngram_size=no_repeat_ngram_size,
656
+ bad_words_ids=bad_words_ids,
657
+ cur_len=cur_len,
658
+ min_length=min_length,
659
+ max_length=max_length,
660
+ eos_token_id=eos_token_id,
661
+ repetition_penalty=repetition_penalty,
662
+ batch_size=batch_size,
663
+ num_beams=num_beams,
664
+ )
665
+
666
+ assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
667
+ scores.shape, (batch_size * num_beams, vocab_size)
668
+ )
669
+
670
+ if do_sample:
671
+ _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
672
+ # Temperature
673
+ if temperature != 1.0:
674
+ _scores = _scores / temperature
675
+ # Top-p/top-k filtering
676
+ _scores = top_k_top_p_filtering(
677
+ _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
678
+ ) # (batch_size * num_beams, vocab_size)
679
+ # re-organize to group the beam together to sample from all beam_idxs
680
+ _scores = _scores.contiguous().view(
681
+ batch_size, num_beams * vocab_size
682
+ ) # (batch_size, num_beams * vocab_size)
683
+
684
+ # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
685
+ probs = F.softmax(_scores, dim=-1)
686
+ next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
687
+ # Compute next scores
688
+ next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
689
+ # sort the sampled vector to make sure that the first num_beams samples are the best
690
+ next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
691
+ next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
692
+
693
+ else:
694
+ next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
695
+
696
+ # re-organize to group the beams together (we are keeping the top hypotheses across beams)
697
+ next_scores = next_scores.view(
698
+ batch_size, num_beams * vocab_size
699
+ ) # (batch_size, num_beams * vocab_size)
700
+
701
+ next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
702
+
703
+ assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
704
+
705
+ # next batch beam content
706
+ next_batch_beam = []
707
+
708
+ # for each sentence
709
+ for batch_idx in range(batch_size):
710
+
711
+ # if we are done with this sentence, add a pad token
712
+ if done[batch_idx]:
713
+ assert (
714
+ len(generated_hyps[batch_idx]) >= num_beams
715
+ ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
716
+ assert (
717
+ eos_token_id is not None and pad_token_id is not None
718
+ ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
719
+ next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
720
+ continue
721
+
722
+ # next sentence beam content, this will get added to next_batch_beam
723
+ next_sent_beam = []
724
+
725
+ # next tokens for this sentence
726
+ for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
727
+ zip(next_tokens[batch_idx], next_scores[batch_idx])
728
+ ):
729
+ # get beam and token IDs
730
+ beam_id = beam_token_id // vocab_size
731
+ token_id = beam_token_id % vocab_size
732
+
733
+ effective_beam_id = batch_idx * num_beams + beam_id
734
+ # add to generated hypotheses if end of sentence
735
+ if (eos_token_id is not None) and (token_id.item() == eos_token_id):
736
+ # if beam_token does not belong to top num_beams tokens, it should not be added
737
+ is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
738
+ if is_beam_token_worse_than_top_num_beams:
739
+ continue
740
+ generated_hyps[batch_idx].add(
741
+ input_ids[effective_beam_id].clone(), beam_token_score.item(),
742
+ )
743
+ else:
744
+ # add next predicted token since it is not eos_token
745
+ next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
746
+
747
+ # once the beam for next step is full, don't add more tokens to it.
748
+ if len(next_sent_beam) == num_beams:
749
+ break
750
+
751
+ # Check if we are done so that we can save a pad step if all(done)
752
+ done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
753
+ next_scores[batch_idx].max().item(), cur_len
754
+ )
755
+
756
+ # update next beam content
757
+ assert len(next_sent_beam) == num_beams, "Beam should always be full"
758
+ next_batch_beam.extend(next_sent_beam)
759
+ assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
760
+
761
+ # stop when we are done with each sentence
762
+ if all(done):
763
+ break
764
+
765
+ # sanity check / prepare next batch
766
+ assert len(next_batch_beam) == batch_size * num_beams
767
+ beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
768
+ beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
769
+ beam_idx = input_ids.new([x[2] for x in next_batch_beam])
770
+
771
+ # re-order batch and update current length
772
+ input_ids = input_ids[beam_idx, :]
773
+ input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
774
+ cur_len = cur_len + 1
775
+
776
+ # re-order internal states
777
+ if past is not None:
778
+ past = self._reorder_cache(past, beam_idx)
779
+
780
+ # extend attention_mask for new generated input if only decoder
781
+ if self.config.is_encoder_decoder is False:
782
+ attention_mask = torch.cat(
783
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
784
+ )
785
+
786
+ # finalize all open beam hypotheses and add to generated hypotheses
787
+ for batch_idx in range(batch_size):
788
+ if done[batch_idx]:
789
+ continue
790
+
791
+ # test that beam scores match previously calculated scores if not eos and batch_idx not done
792
+ if eos_token_id is not None and all(
793
+ (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
794
+ ):
795
+ assert torch.all(
796
+ next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
797
+ ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
798
+ next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
799
+ )
800
+
801
+ # need to add best num_beams hypotheses to generated hyps
802
+ for beam_id in range(num_beams):
803
+ effective_beam_id = batch_idx * num_beams + beam_id
804
+ final_score = beam_scores[effective_beam_id].item()
805
+ final_tokens = input_ids[effective_beam_id]
806
+ generated_hyps[batch_idx].add(final_tokens, final_score)
807
+
808
+ # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
809
+ output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
810
+ output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
811
+
812
+ # select the best hypotheses
813
+ sent_lengths = input_ids.new(output_batch_size)
814
+ best = []
815
+
816
+ # retrieve best hypotheses
817
+ for i, hypotheses in enumerate(generated_hyps):
818
+ sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
819
+ for j in range(output_num_return_sequences_per_batch):
820
+ effective_batch_idx = output_num_return_sequences_per_batch * i + j
821
+ best_hyp = sorted_hyps.pop()[1]
822
+ sent_lengths[effective_batch_idx] = len(best_hyp)
823
+ best.append(best_hyp)
824
+
825
+ # shorter batches are padded
826
+ if sent_lengths.min().item() != sent_lengths.max().item():
827
+ assert pad_token_id is not None, "`Pad_token_id` has to be defined"
828
+ sent_max_len = min(sent_lengths.max().item() + 1, max_length)
829
+ decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
830
+
831
+ # fill with hypothesis and eos_token_id if necessary
832
+ for i, hypo in enumerate(best):
833
+ decoded[i, : sent_lengths[i]] = hypo
834
+ if sent_lengths[i] < max_length:
835
+ decoded[i, sent_lengths[i]] = eos_token_id
836
+ else:
837
+ # none of the hypotheses have an eos_token
838
+ assert all(len(hypo) == max_length for hypo in best)
839
+ decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
840
+
841
+ return decoded
842
+
843
+ @staticmethod
844
+ def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
845
+ return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
846
+
847
+
848
+ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> Iterable[Iterable[int]]:
849
+ """Copied from fairseq for no_repeat_ngram in beam_search"""
850
+ if cur_len + 1 < no_repeat_ngram_size:
851
+ # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
852
+ return [[] for _ in range(num_hypos)]
853
+ generated_ngrams = [{} for _ in range(num_hypos)]
854
+ for idx in range(num_hypos):
855
+ gen_tokens = prev_input_ids[idx].tolist()
856
+ generated_ngram = generated_ngrams[idx]
857
+ for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
858
+ prev_ngram_tuple = tuple(ngram[:-1])
859
+ generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
860
+
861
+ def _get_generated_ngrams(hypo_idx):
862
+ # Before decoding the next token, prevent decoding of ngrams that have already appeared
863
+ start_idx = cur_len + 1 - no_repeat_ngram_size
864
+ ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
865
+ return generated_ngrams[hypo_idx].get(ngram_idx, [])
866
+
867
+ banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
868
+ return banned_tokens
869
+
870
+
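+ # Illustrative sketch (editor's addition): with no_repeat_ngram_size=2 and a hypothesis whose
+ # tokens so far are [5, 7, 5], the bigram (5, 7) has already been generated; since the current
+ # prefix ends in 5, the function returns [7] for that hypothesis and the caller sets
+ # scores[i, 7] = -inf, so the bigram cannot be produced again.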
871
+ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
872
+ banned_tokens = []
873
+
874
+ def _tokens_match(prev_tokens, tokens):
875
+ if len(tokens) == 0:
876
+ # if bad word tokens is just one token always ban it
877
+ return True
878
+ if len(tokens) > len(prev_input_ids):
879
+ # if bad word tokens are longer than prev input_ids they can't be equal
880
+ return False
881
+
882
+ if prev_tokens[-len(tokens) :] == tokens:
883
+ # if tokens match
884
+ return True
885
+ else:
886
+ return False
887
+
888
+ for prev_input_ids_slice in prev_input_ids:
889
+ banned_tokens_slice = []
890
+
891
+ for banned_token_seq in bad_words_ids:
892
+ assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
893
+ bad_words_ids
894
+ )
895
+
896
+ if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
897
+ # if tokens do not match continue
898
+ continue
899
+
900
+ banned_tokens_slice.append(banned_token_seq[-1])
901
+
902
+ banned_tokens.append(banned_tokens_slice)
903
+
904
+ return banned_tokens
905
+
906
+
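# Editor's illustrative sketch (not part of the vendored file above): ban the
# two-token sequence [4, 5] and the single token [9]. Token 5 is banned only
# because the hypothesis already ends with 4; token 9 is banned unconditionally.
# Token ids are arbitrary; torch is assumed to be imported at module level.
if __name__ == "__main__":
    prev_ids = torch.tensor([[2, 3, 4]])
    banned = calc_banned_bad_words_ids(prev_ids, bad_words_ids=[[4, 5], [9]])
    assert banned == [[5, 9]]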
907
+ def top_k_top_p_filtering(
908
+ logits: Tensor,
909
+ top_k: int = 0,
910
+ top_p: float = 1.0,
911
+ filter_value: float = -float("Inf"),
912
+ min_tokens_to_keep: int = 1,
913
+ ) -> Tensor:
914
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
915
+ Args:
916
+ logits: logits distribution shape (batch size, vocabulary size)
917
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
918
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
919
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
920
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
921
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
922
+ """
923
+ if top_k > 0:
924
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
925
+ # Remove all tokens with a probability less than the last token of the top-k
926
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
927
+ logits[indices_to_remove] = filter_value
928
+
929
+ if top_p < 1.0:
930
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
931
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
932
+
933
+ # Remove tokens with cumulative probability above the threshold (the highest-probability token is always kept)
934
+ sorted_indices_to_remove = cumulative_probs > top_p
935
+ if min_tokens_to_keep > 1:
936
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
937
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
938
+ # Shift the indices to the right to keep also the first token above the threshold
939
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
940
+ sorted_indices_to_remove[..., 0] = 0
941
+
942
+ # scatter sorted tensors to original indexing
943
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
944
+ logits[indices_to_remove] = filter_value
945
+ return logits
946
+
947
+
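# Editor's illustrative sketch (not part of the vendored file above): keep only the
# two highest logits (top-k), drop anything past 95% cumulative mass (top-p), then
# sample from the renormalized distribution. The logit values are arbitrary.
if __name__ == "__main__":
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=0.95)  # clone: filtering is in-place
    probs = torch.softmax(filtered, dim=-1)        # filtered-out entries get probability 0
    next_token = torch.multinomial(probs, num_samples=1)
    assert next_token.item() in (0, 1)             # only the two surviving tokens can be drawn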
948
+ class BeamHypotheses(object):
949
+ def __init__(self, num_beams, max_length, length_penalty, early_stopping):
950
+ """
951
+ Initialize n-best list of hypotheses.
952
+ """
953
+ self.max_length = max_length - 1 # ignoring bos_token
954
+ self.length_penalty = length_penalty
955
+ self.early_stopping = early_stopping
956
+ self.num_beams = num_beams
957
+ self.beams = []
958
+ self.worst_score = 1e9
959
+
960
+ def __len__(self):
961
+ """
962
+ Number of hypotheses in the list.
963
+ """
964
+ return len(self.beams)
965
+
966
+ def add(self, hyp, sum_logprobs):
967
+ """
968
+ Add a new hypothesis to the list.
969
+ """
970
+ score = sum_logprobs / len(hyp) ** self.length_penalty
971
+ if len(self) < self.num_beams or score > self.worst_score:
972
+ self.beams.append((score, hyp))
973
+ if len(self) > self.num_beams:
974
+ sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
975
+ del self.beams[sorted_scores[0][1]]
976
+ self.worst_score = sorted_scores[1][0]
977
+ else:
978
+ self.worst_score = min(score, self.worst_score)
979
+
980
+ def is_done(self, best_sum_logprobs, cur_len):
981
+ """
982
+ If there are enough hypotheses and none of the hypotheses being generated
983
+ can become better than the worst one in the heap, then we are done with this sentence.
984
+ """
985
+
986
+ if len(self) < self.num_beams:
987
+ return False
988
+ elif self.early_stopping:
989
+ return True
990
+ else:
991
+ cur_score = best_sum_logprobs / cur_len ** self.length_penalty
992
+ ret = self.worst_score >= cur_score
993
+ return ret
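# Editor's illustrative sketch (not part of the vendored file above): a 2-beam
# hypothesis list keeps only the two best length-normalized scores, so the weaker
# third candidate is rejected. Token ids and log-probabilities are arbitrary.
if __name__ == "__main__":
    hyps = BeamHypotheses(num_beams=2, max_length=20, length_penalty=1.0, early_stopping=False)
    hyps.add(torch.tensor([101, 7, 8, 102]), sum_logprobs=-2.0)      # score -2.0 / 4 = -0.5
    hyps.add(torch.tensor([101, 9, 102]), sum_logprobs=-1.5)         # score -1.5 / 3 = -0.5
    hyps.add(torch.tensor([101, 4, 4, 4, 102]), sum_logprobs=-9.0)   # score -1.8, discarded
    assert len(hyps) == 2 and hyps.worst_score == -0.5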
LAVT-RIS/bert/modeling_bert.py ADDED
@@ -0,0 +1,1569 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model. """
17
+
18
+
19
+ import logging
20
+ import math
21
+ import os
22
+ import warnings
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import CrossEntropyLoss, MSELoss
28
+
29
+ from .activations import gelu, gelu_new, swish
30
+ from .configuration_bert import BertConfig
31
+ from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
32
+ from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
33
+
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
38
+
39
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
40
+ "bert-base-uncased",
41
+ "bert-large-uncased",
42
+ "bert-base-cased",
43
+ "bert-large-cased",
44
+ "bert-base-multilingual-uncased",
45
+ "bert-base-multilingual-cased",
46
+ "bert-base-chinese",
47
+ "bert-base-german-cased",
48
+ "bert-large-uncased-whole-word-masking",
49
+ "bert-large-cased-whole-word-masking",
50
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
51
+ "bert-large-cased-whole-word-masking-finetuned-squad",
52
+ "bert-base-cased-finetuned-mrpc",
53
+ "bert-base-german-dbmdz-cased",
54
+ "bert-base-german-dbmdz-uncased",
55
+ "cl-tohoku/bert-base-japanese",
56
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
57
+ "cl-tohoku/bert-base-japanese-char",
58
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
59
+ "TurkuNLP/bert-base-finnish-cased-v1",
60
+ "TurkuNLP/bert-base-finnish-uncased-v1",
61
+ "wietsedv/bert-base-dutch-cased",
62
+ # See all BERT models at https://huggingface.co/models?filter=bert
63
+ ]
64
+
65
+
66
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
67
+ """ Load tf checkpoints in a pytorch model.
68
+ """
69
+ try:
70
+ import re
71
+ import numpy as np
72
+ import tensorflow as tf
73
+ except ImportError:
74
+ logger.error(
75
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
76
+ "https://www.tensorflow.org/install/ for installation instructions."
77
+ )
78
+ raise
79
+ tf_path = os.path.abspath(tf_checkpoint_path)
80
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
81
+ # Load weights from TF model
82
+ init_vars = tf.train.list_variables(tf_path)
83
+ names = []
84
+ arrays = []
85
+ for name, shape in init_vars:
86
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
87
+ array = tf.train.load_variable(tf_path, name)
88
+ names.append(name)
89
+ arrays.append(array)
90
+
91
+ for name, array in zip(names, arrays):
92
+ name = name.split("/")
93
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
94
+ # which are not required for using pretrained model
95
+ if any(
96
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
97
+ for n in name
98
+ ):
99
+ logger.info("Skipping {}".format("/".join(name)))
100
+ continue
101
+ pointer = model
102
+ for m_name in name:
103
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
104
+ scope_names = re.split(r"_(\d+)", m_name)
105
+ else:
106
+ scope_names = [m_name]
107
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
108
+ pointer = getattr(pointer, "weight")
109
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
110
+ pointer = getattr(pointer, "bias")
111
+ elif scope_names[0] == "output_weights":
112
+ pointer = getattr(pointer, "weight")
113
+ elif scope_names[0] == "squad":
114
+ pointer = getattr(pointer, "classifier")
115
+ else:
116
+ try:
117
+ pointer = getattr(pointer, scope_names[0])
118
+ except AttributeError:
119
+ logger.info("Skipping {}".format("/".join(name)))
120
+ continue
121
+ if len(scope_names) >= 2:
122
+ num = int(scope_names[1])
123
+ pointer = pointer[num]
124
+ if m_name[-11:] == "_embeddings":
125
+ pointer = getattr(pointer, "weight")
126
+ elif m_name == "kernel":
127
+ array = np.transpose(array)
128
+ try:
129
+ assert pointer.shape == array.shape
130
+ except AssertionError as e:
131
+ e.args += (pointer.shape, array.shape)
132
+ raise
133
+ logger.info("Initialize PyTorch weight {}".format(name))
134
+ pointer.data = torch.from_numpy(array)
135
+ return model
136
+
137
+
138
+ def mish(x):
139
+ return x * torch.tanh(nn.functional.softplus(x))
140
+
141
+
142
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
143
+
144
+
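# Editor's illustrative sketch (not part of the vendored file above): config.hidden_act
# is a string key looked up in ACT2FN, so e.g. "mish" resolves to the callable defined
# just above. mish(0) == 0, which the check below relies on.
if __name__ == "__main__":
    assert torch.allclose(ACT2FN["mish"](torch.zeros(3)), torch.zeros(3))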
145
+ BertLayerNorm = torch.nn.LayerNorm
146
+
147
+
148
+ class BertEmbeddings(nn.Module):
149
+ """Construct the embeddings from word, position and token_type embeddings.
150
+ """
151
+
152
+ def __init__(self, config):
153
+ super().__init__()
154
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
155
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
156
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
157
+
158
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
159
+ # any TensorFlow checkpoint file
160
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
161
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
162
+
163
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
164
+ if input_ids is not None:
165
+ input_shape = input_ids.size()
166
+ else:
167
+ input_shape = inputs_embeds.size()[:-1]
168
+
169
+ seq_length = input_shape[1]
170
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
171
+ if position_ids is None:
172
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
173
+ position_ids = position_ids.unsqueeze(0).expand(input_shape)
174
+ if token_type_ids is None:
175
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
176
+
177
+ if inputs_embeds is None:
178
+ inputs_embeds = self.word_embeddings(input_ids)
179
+ position_embeddings = self.position_embeddings(position_ids)
180
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
181
+
182
+ embeddings = inputs_embeds + position_embeddings + token_type_embeddings
183
+ embeddings = self.LayerNorm(embeddings)
184
+ embeddings = self.dropout(embeddings)
185
+ return embeddings
186
+
187
+
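# Editor's illustrative sketch (not part of the vendored file above): the embedding
# output is the sum of word, position and token-type lookups, normalized and dropped
# out, so its shape is (batch, seq_len, hidden_size). The tiny config values are arbitrary.
if __name__ == "__main__":
    tiny_config = BertConfig(vocab_size=100, hidden_size=32, max_position_embeddings=16)
    emb = BertEmbeddings(tiny_config)
    out = emb(input_ids=torch.randint(0, 100, (2, 8)))
    assert out.shape == (2, 8, 32)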
188
+ class BertSelfAttention(nn.Module):
189
+ def __init__(self, config):
190
+ super().__init__()
191
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
192
+ raise ValueError(
193
+ "The hidden size (%d) is not a multiple of the number of attention "
194
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
195
+ )
196
+
197
+ self.num_attention_heads = config.num_attention_heads
198
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
199
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
200
+
201
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
202
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
203
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
204
+
205
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
206
+
207
+ def transpose_for_scores(self, x):
208
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
209
+ x = x.view(*new_x_shape)
210
+ return x.permute(0, 2, 1, 3)
211
+
212
+ def forward(
213
+ self,
214
+ hidden_states,
215
+ attention_mask=None,
216
+ head_mask=None,
217
+ encoder_hidden_states=None,
218
+ encoder_attention_mask=None,
219
+ output_attentions=False,
220
+ ):
221
+ mixed_query_layer = self.query(hidden_states)
222
+
223
+ # If this is instantiated as a cross-attention module, the keys
224
+ # and values come from an encoder; the attention mask needs to be
225
+ # such that the encoder's padding tokens are not attended to.
226
+ if encoder_hidden_states is not None:
227
+ mixed_key_layer = self.key(encoder_hidden_states)
228
+ mixed_value_layer = self.value(encoder_hidden_states)
229
+ attention_mask = encoder_attention_mask
230
+ else:
231
+ mixed_key_layer = self.key(hidden_states)
232
+ mixed_value_layer = self.value(hidden_states)
233
+
234
+ query_layer = self.transpose_for_scores(mixed_query_layer)
235
+ key_layer = self.transpose_for_scores(mixed_key_layer)
236
+ value_layer = self.transpose_for_scores(mixed_value_layer)
237
+
238
+ # Take the dot product between "query" and "key" to get the raw attention scores.
239
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
240
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
241
+ if attention_mask is not None:
242
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
243
+ attention_scores = attention_scores + attention_mask
244
+
245
+ # Normalize the attention scores to probabilities.
246
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
247
+
248
+ # This is actually dropping out entire tokens to attend to, which might
249
+ # seem a bit unusual, but is taken from the original Transformer paper.
250
+ attention_probs = self.dropout(attention_probs)
251
+
252
+ # Mask heads if we want to
253
+ if head_mask is not None:
254
+ attention_probs = attention_probs * head_mask
255
+
256
+ context_layer = torch.matmul(attention_probs, value_layer)
257
+
258
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
259
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
260
+ context_layer = context_layer.view(*new_context_layer_shape)
261
+
262
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
263
+ return outputs
264
+
265
+
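# Editor's illustrative sketch (not part of the vendored file above): self-attention
# preserves the (batch, seq_len, hidden_size) shape; transpose_for_scores splits the
# hidden dimension into (num_heads, head_size) so each head attends independently.
# The tiny config values are arbitrary.
if __name__ == "__main__":
    tiny_config = BertConfig(hidden_size=32, num_attention_heads=4)
    attn = BertSelfAttention(tiny_config)
    hidden = torch.randn(2, 5, 32)                        # (batch, seq_len, hidden_size)
    context, probs = attn(hidden, output_attentions=True)
    assert context.shape == (2, 5, 32) and probs.shape == (2, 4, 5, 5)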
266
+ class BertSelfOutput(nn.Module):
267
+ def __init__(self, config):
268
+ super().__init__()
269
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
270
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
271
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
272
+
273
+ def forward(self, hidden_states, input_tensor):
274
+ hidden_states = self.dense(hidden_states)
275
+ hidden_states = self.dropout(hidden_states)
276
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
277
+ return hidden_states
278
+
279
+
280
+ class BertAttention(nn.Module):
281
+ def __init__(self, config):
282
+ super().__init__()
283
+ self.self = BertSelfAttention(config)
284
+ self.output = BertSelfOutput(config)
285
+ self.pruned_heads = set()
286
+
287
+ def prune_heads(self, heads):
288
+ if len(heads) == 0:
289
+ return
290
+ heads, index = find_pruneable_heads_and_indices(
291
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
292
+ )
293
+
294
+ # Prune linear layers
295
+ self.self.query = prune_linear_layer(self.self.query, index)
296
+ self.self.key = prune_linear_layer(self.self.key, index)
297
+ self.self.value = prune_linear_layer(self.self.value, index)
298
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
299
+
300
+ # Update hyper params and store pruned heads
301
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
302
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
303
+ self.pruned_heads = self.pruned_heads.union(heads)
304
+
305
+ def forward(
306
+ self,
307
+ hidden_states,
308
+ attention_mask=None,
309
+ head_mask=None,
310
+ encoder_hidden_states=None,
311
+ encoder_attention_mask=None,
312
+ output_attentions=False,
313
+ ):
314
+ self_outputs = self.self(
315
+ hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
316
+ )
317
+ attention_output = self.output(self_outputs[0], hidden_states)
318
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
319
+ return outputs
320
+
321
+
322
+ class BertIntermediate(nn.Module):
323
+ def __init__(self, config):
324
+ super().__init__()
325
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
326
+ if isinstance(config.hidden_act, str):
327
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
328
+ else:
329
+ self.intermediate_act_fn = config.hidden_act
330
+
331
+ def forward(self, hidden_states):
332
+ hidden_states = self.dense(hidden_states)
333
+ hidden_states = self.intermediate_act_fn(hidden_states)
334
+ return hidden_states
335
+
336
+
337
+ class BertOutput(nn.Module):
338
+ def __init__(self, config):
339
+ super().__init__()
340
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
341
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
342
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
343
+
344
+ def forward(self, hidden_states, input_tensor):
345
+ hidden_states = self.dense(hidden_states)
346
+ hidden_states = self.dropout(hidden_states)
347
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
348
+ return hidden_states
349
+
350
+
351
+ class BertLayer(nn.Module):
352
+ def __init__(self, config):
353
+ super().__init__()
354
+ self.attention = BertAttention(config)
355
+ self.is_decoder = config.is_decoder
356
+ if self.is_decoder:
357
+ self.crossattention = BertAttention(config)
358
+ self.intermediate = BertIntermediate(config)
359
+ self.output = BertOutput(config)
360
+
361
+ def forward(
362
+ self,
363
+ hidden_states,
364
+ attention_mask=None,
365
+ head_mask=None,
366
+ encoder_hidden_states=None,
367
+ encoder_attention_mask=None,
368
+ output_attentions=False,
369
+ ):
370
+ self_attention_outputs = self.attention(
371
+ hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
372
+ )
373
+ attention_output = self_attention_outputs[0]
374
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
375
+
376
+ if self.is_decoder and encoder_hidden_states is not None:
377
+ cross_attention_outputs = self.crossattention(
378
+ attention_output,
379
+ attention_mask,
380
+ head_mask,
381
+ encoder_hidden_states,
382
+ encoder_attention_mask,
383
+ output_attentions,
384
+ )
385
+ attention_output = cross_attention_outputs[0]
386
+ outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
387
+
388
+ intermediate_output = self.intermediate(attention_output)
389
+ layer_output = self.output(intermediate_output, attention_output)
390
+ outputs = (layer_output,) + outputs
391
+ return outputs
392
+
393
+
394
+ class BertEncoder(nn.Module):
395
+ def __init__(self, config):
396
+ super().__init__()
397
+ self.config = config
398
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
399
+
400
+ def forward(
401
+ self,
402
+ hidden_states,
403
+ attention_mask=None,
404
+ head_mask=None,
405
+ encoder_hidden_states=None,
406
+ encoder_attention_mask=None,
407
+ output_attentions=False,
408
+ output_hidden_states=False,
409
+ ):
410
+ all_hidden_states = ()
411
+ all_attentions = ()
412
+ for i, layer_module in enumerate(self.layer):
413
+ if output_hidden_states:
414
+ all_hidden_states = all_hidden_states + (hidden_states,)
415
+
416
+ if getattr(self.config, "gradient_checkpointing", False):
417
+
418
+ def create_custom_forward(module):
419
+ def custom_forward(*inputs):
420
+ return module(*inputs, output_attentions)
421
+
422
+ return custom_forward
423
+
424
+ layer_outputs = torch.utils.checkpoint.checkpoint(
425
+ create_custom_forward(layer_module),
426
+ hidden_states,
427
+ attention_mask,
428
+ head_mask[i],
429
+ encoder_hidden_states,
430
+ encoder_attention_mask,
431
+ )
432
+ else:
433
+ layer_outputs = layer_module(
434
+ hidden_states,
435
+ attention_mask,
436
+ head_mask[i],
437
+ encoder_hidden_states,
438
+ encoder_attention_mask,
439
+ output_attentions,
440
+ )
441
+ hidden_states = layer_outputs[0]
442
+
443
+ if output_attentions:
444
+ all_attentions = all_attentions + (layer_outputs[1],)
445
+
446
+ # Add last layer
447
+ if output_hidden_states:
448
+ all_hidden_states = all_hidden_states + (hidden_states,)
449
+
450
+ outputs = (hidden_states,)
451
+ if output_hidden_states:
452
+ outputs = outputs + (all_hidden_states,)
453
+ if output_attentions:
454
+ outputs = outputs + (all_attentions,)
455
+ return outputs # last-layer hidden state, (all hidden states), (all attentions)
456
+
457
+
458
+ class BertPooler(nn.Module):
459
+ def __init__(self, config):
460
+ super().__init__()
461
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
462
+ self.activation = nn.Tanh()
463
+
464
+ def forward(self, hidden_states):
465
+ # We "pool" the model by simply taking the hidden state corresponding
466
+ # to the first token.
467
+ first_token_tensor = hidden_states[:, 0]
468
+ pooled_output = self.dense(first_token_tensor)
469
+ pooled_output = self.activation(pooled_output)
470
+ return pooled_output
471
+
472
+
473
+ class BertPredictionHeadTransform(nn.Module):
474
+ def __init__(self, config):
475
+ super().__init__()
476
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
477
+ if isinstance(config.hidden_act, str):
478
+ self.transform_act_fn = ACT2FN[config.hidden_act]
479
+ else:
480
+ self.transform_act_fn = config.hidden_act
481
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
482
+
483
+ def forward(self, hidden_states):
484
+ hidden_states = self.dense(hidden_states)
485
+ hidden_states = self.transform_act_fn(hidden_states)
486
+ hidden_states = self.LayerNorm(hidden_states)
487
+ return hidden_states
488
+
489
+
490
+ class BertLMPredictionHead(nn.Module):
491
+ def __init__(self, config):
492
+ super().__init__()
493
+ self.transform = BertPredictionHeadTransform(config)
494
+
495
+ # The output weights are the same as the input embeddings, but there is
496
+ # an output-only bias for each token.
497
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
498
+
499
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
500
+
501
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
502
+ self.decoder.bias = self.bias
503
+
504
+ def forward(self, hidden_states):
505
+ hidden_states = self.transform(hidden_states)
506
+ hidden_states = self.decoder(hidden_states)
507
+ return hidden_states
508
+
509
+
510
+ class BertOnlyMLMHead(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.predictions = BertLMPredictionHead(config)
514
+
515
+ def forward(self, sequence_output):
516
+ prediction_scores = self.predictions(sequence_output)
517
+ return prediction_scores
518
+
519
+
520
+ class BertOnlyNSPHead(nn.Module):
521
+ def __init__(self, config):
522
+ super().__init__()
523
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
524
+
525
+ def forward(self, pooled_output):
526
+ seq_relationship_score = self.seq_relationship(pooled_output)
527
+ return seq_relationship_score
528
+
529
+
530
+ class BertPreTrainingHeads(nn.Module):
531
+ def __init__(self, config):
532
+ super().__init__()
533
+ self.predictions = BertLMPredictionHead(config)
534
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
535
+
536
+ def forward(self, sequence_output, pooled_output):
537
+ prediction_scores = self.predictions(sequence_output)
538
+ seq_relationship_score = self.seq_relationship(pooled_output)
539
+ return prediction_scores, seq_relationship_score
540
+
541
+
542
+ class BertPreTrainedModel(PreTrainedModel):
543
+ """ An abstract class to handle weights initialization and
544
+ a simple interface for downloading and loading pretrained models.
545
+ """
546
+
547
+ config_class = BertConfig
548
+ load_tf_weights = load_tf_weights_in_bert
549
+ base_model_prefix = "bert"
550
+
551
+ def _init_weights(self, module):
552
+ """ Initialize the weights """
553
+ if isinstance(module, (nn.Linear, nn.Embedding)):
554
+ # Slightly different from the TF version which uses truncated_normal for initialization
555
+ # cf https://github.com/pytorch/pytorch/pull/5617
556
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
557
+ elif isinstance(module, BertLayerNorm):
558
+ module.bias.data.zero_()
559
+ module.weight.data.fill_(1.0)
560
+ if isinstance(module, nn.Linear) and module.bias is not None:
561
+ module.bias.data.zero_()
562
+
563
+
564
+ BERT_START_DOCSTRING = r"""
565
+ This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
566
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
567
+ usage and behavior.
568
+
569
+ Parameters:
570
+ config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
571
+ Initializing with a config file does not load the weights associated with the model, only the configuration.
572
+ Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
573
+ """
574
+
575
+ BERT_INPUTS_DOCSTRING = r"""
576
+ Args:
577
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
578
+ Indices of input sequence tokens in the vocabulary.
579
+
580
+ Indices can be obtained using :class:`transformers.BertTokenizer`.
581
+ See :func:`transformers.PreTrainedTokenizer.encode` and
582
+ :func:`transformers.PreTrainedTokenizer.__call__` for details.
583
+
584
+ `What are input IDs? <../glossary.html#input-ids>`__
585
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
586
+ Mask to avoid performing attention on padding token indices.
587
+ Mask values selected in ``[0, 1]``:
588
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
589
+
590
+ `What are attention masks? <../glossary.html#attention-mask>`__
591
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
592
+ Segment token indices to indicate first and second portions of the inputs.
593
+ Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
594
+ corresponds to a `sentence B` token
595
+
596
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
597
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
598
+ Indices of positions of each input sequence tokens in the position embeddings.
599
+ Selected in the range ``[0, config.max_position_embeddings - 1]``.
600
+
601
+ `What are position IDs? <../glossary.html#position-ids>`_
602
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
603
+ Mask to nullify selected heads of the self-attention modules.
604
+ Mask values selected in ``[0, 1]``:
605
+ :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
606
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
607
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
608
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
609
+ than the model's internal embedding lookup matrix.
610
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
611
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
612
+ if the model is configured as a decoder.
613
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
614
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask
615
+ is used in the cross-attention if the model is configured as a decoder.
616
+ Mask values selected in ``[0, 1]``:
617
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
618
+ output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
619
+ If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
620
+ """
621
+
622
+
623
+ @add_start_docstrings(
624
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
625
+ BERT_START_DOCSTRING,
626
+ )
627
+ class BertModel(BertPreTrainedModel):
628
+ """
629
+
630
+ The model can behave as an encoder (with only self-attention) as well
631
+ as a decoder, in which case a layer of cross-attention is added between
632
+ the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
633
+ Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
634
+
635
+ To behave as a decoder, the model needs to be initialized with the
636
+ :obj:`is_decoder` argument of the configuration set to :obj:`True`; an
637
+ :obj:`encoder_hidden_states` is expected as an input to the forward pass.
638
+
639
+ .. _`Attention is all you need`:
640
+ https://arxiv.org/abs/1706.03762
641
+
642
+ """
643
+
644
+ def __init__(self, config):
645
+ super().__init__(config)
646
+ self.config = config
647
+
648
+ self.embeddings = BertEmbeddings(config)
649
+ self.encoder = BertEncoder(config)
650
+ self.pooler = BertPooler(config)
651
+
652
+ self.init_weights()
653
+
654
+ def get_input_embeddings(self):
655
+ return self.embeddings.word_embeddings
656
+
657
+ def set_input_embeddings(self, value):
658
+ self.embeddings.word_embeddings = value
659
+
660
+ def _prune_heads(self, heads_to_prune):
661
+ """ Prunes heads of the model.
662
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
663
+ See base class PreTrainedModel
664
+ """
665
+ for layer, heads in heads_to_prune.items():
666
+ self.encoder.layer[layer].attention.prune_heads(heads)
667
+
668
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
669
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ token_type_ids=None,
675
+ position_ids=None,
676
+ head_mask=None,
677
+ inputs_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ output_attentions=None,
681
+ output_hidden_states=None,
682
+ ):
683
+ r"""
684
+ Return:
685
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
686
+ last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
687
+ Sequence of hidden-states at the output of the last layer of the model.
688
+ pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
689
+ Last layer hidden-state of the first token of the sequence (classification token)
690
+ further processed by a Linear layer and a Tanh activation function. The Linear
691
+ layer weights are trained from the next sentence prediction (classification)
692
+ objective during pre-training.
693
+
694
+ This output is usually *not* a good summary
695
+ of the semantic content of the input; you're often better off averaging or pooling
696
+ the sequence of hidden-states for the whole input sequence.
697
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
698
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
699
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
700
+
701
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
702
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
703
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
704
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
705
+
706
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
707
+ heads.
708
+ """
709
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
710
+ output_hidden_states = (
711
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
712
+ )
713
+
714
+ if input_ids is not None and inputs_embeds is not None:
715
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
716
+ elif input_ids is not None:
717
+ input_shape = input_ids.size()
718
+ elif inputs_embeds is not None:
719
+ input_shape = inputs_embeds.size()[:-1]
720
+ else:
721
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
722
+
723
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
724
+
725
+ if attention_mask is None:
726
+ attention_mask = torch.ones(input_shape, device=device)
727
+ if token_type_ids is None:
728
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
729
+
730
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
731
+ # ourselves in which case we just need to make it broadcastable to all heads.
732
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
733
+
734
+ # If a 2D or 3D attention mask is provided for the cross-attention
735
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
736
+ if self.config.is_decoder and encoder_hidden_states is not None:
737
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
738
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
739
+ if encoder_attention_mask is None:
740
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
741
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
742
+ else:
743
+ encoder_extended_attention_mask = None
744
+
745
+ # Prepare head mask if needed
746
+ # 1.0 in head_mask indicate we keep the head
747
+ # attention_probs has shape bsz x n_heads x N x N
748
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
749
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
750
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
751
+
752
+ embedding_output = self.embeddings(
753
+ input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
754
+ )
755
+ encoder_outputs = self.encoder(
756
+ embedding_output,
757
+ attention_mask=extended_attention_mask,
758
+ head_mask=head_mask,
759
+ encoder_hidden_states=encoder_hidden_states,
760
+ encoder_attention_mask=encoder_extended_attention_mask,
761
+ output_attentions=output_attentions,
762
+ output_hidden_states=output_hidden_states,
763
+ )
764
+ sequence_output = encoder_outputs[0]
765
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
766
+
767
+ outputs = (sequence_output, pooled_output,) + encoder_outputs[
768
+ 1:
769
+ ] # add hidden_states and attentions if they are here
770
+ return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
771
+
772
+
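# Editor's illustrative sketch (not part of the vendored file above): run the bare
# encoder with a randomly initialized tiny config. sequence_output holds per-token
# states; pooled_output is the transformed [CLS] state. Config values are arbitrary.
if __name__ == "__main__":
    tiny_config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                             num_attention_heads=4, intermediate_size=64,
                             max_position_embeddings=16)
    model = BertModel(tiny_config)
    input_ids = torch.randint(0, 100, (2, 8))             # (batch, seq_len)
    sequence_output, pooled_output = model(input_ids)[:2]
    assert sequence_output.shape == (2, 8, 32) and pooled_output.shape == (2, 32)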
773
+ @add_start_docstrings(
774
+ """Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
775
+ a `next sentence prediction (classification)` head. """,
776
+ BERT_START_DOCSTRING,
777
+ )
778
+ class BertForPreTraining(BertPreTrainedModel):
779
+ def __init__(self, config):
780
+ super().__init__(config)
781
+
782
+ self.bert = BertModel(config)
783
+ self.cls = BertPreTrainingHeads(config)
784
+
785
+ self.init_weights()
786
+
787
+ def get_output_embeddings(self):
788
+ return self.cls.predictions.decoder
789
+
790
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
791
+ def forward(
792
+ self,
793
+ input_ids=None,
794
+ attention_mask=None,
795
+ token_type_ids=None,
796
+ position_ids=None,
797
+ head_mask=None,
798
+ inputs_embeds=None,
799
+ labels=None,
800
+ next_sentence_label=None,
801
+ output_attentions=None,
802
+ output_hidden_states=None,
803
+ **kwargs
804
+ ):
805
+ r"""
806
+ labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
807
+ Labels for computing the masked language modeling loss.
808
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
809
+ Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
810
+ in ``[0, ..., config.vocab_size]``
811
+ next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
812
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
813
+ Indices should be in ``[0, 1]``.
814
+ ``0`` indicates sequence B is a continuation of sequence A,
815
+ ``1`` indicates sequence B is a random sequence.
816
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
817
+ Used to hide legacy arguments that have been deprecated.
818
+
819
+ Returns:
820
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
821
+ loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
822
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
823
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
824
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
825
+ seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
826
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False
827
+ continuation before SoftMax).
828
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
829
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
830
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
831
+
832
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
833
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
834
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
835
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
836
+
837
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
838
+ heads.
839
+
840
+
841
+ Examples::
842
+
843
+ >>> from transformers import BertTokenizer, BertForPreTraining
844
+ >>> import torch
845
+
846
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
847
+ >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
848
+
849
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
850
+ >>> outputs = model(**inputs)
851
+
852
+ >>> prediction_scores, seq_relationship_scores = outputs[:2]
853
+
854
+ """
855
+ if "masked_lm_labels" in kwargs:
856
+ warnings.warn(
857
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
858
+ DeprecationWarning,
859
+ )
860
+ labels = kwargs.pop("masked_lm_labels")
861
+ assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
862
+
863
+ outputs = self.bert(
864
+ input_ids,
865
+ attention_mask=attention_mask,
866
+ token_type_ids=token_type_ids,
867
+ position_ids=position_ids,
868
+ head_mask=head_mask,
869
+ inputs_embeds=inputs_embeds,
870
+ output_attentions=output_attentions,
871
+ output_hidden_states=output_hidden_states,
872
+ )
873
+
874
+ sequence_output, pooled_output = outputs[:2]
875
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
876
+
877
+ outputs = (prediction_scores, seq_relationship_score,) + outputs[
878
+ 2:
879
+ ] # add hidden states and attention if they are here
880
+
881
+ if labels is not None and next_sentence_label is not None:
882
+ loss_fct = CrossEntropyLoss()
883
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
884
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
885
+ total_loss = masked_lm_loss + next_sentence_loss
886
+ outputs = (total_loss,) + outputs
887
+
888
+ return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
889
+
890
+
891
+ @add_start_docstrings(
892
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
893
+ )
894
+ class BertLMHeadModel(BertPreTrainedModel):
895
+ def __init__(self, config):
896
+ super().__init__(config)
897
+ assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."
898
+
899
+ self.bert = BertModel(config)
900
+ self.cls = BertOnlyMLMHead(config)
901
+
902
+ self.init_weights()
903
+
904
+ def get_output_embeddings(self):
905
+ return self.cls.predictions.decoder
906
+
907
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
908
+ def forward(
909
+ self,
910
+ input_ids=None,
911
+ attention_mask=None,
912
+ token_type_ids=None,
913
+ position_ids=None,
914
+ head_mask=None,
915
+ inputs_embeds=None,
916
+ labels=None,
917
+ encoder_hidden_states=None,
918
+ encoder_attention_mask=None,
919
+ output_attentions=None,
920
+ output_hidden_states=None,
921
+ **kwargs
922
+ ):
923
+ r"""
924
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
925
+ Labels for computing the left-to-right language modeling loss (next word prediction).
926
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
927
+ Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
928
+ in ``[0, ..., config.vocab_size]``
929
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
930
+ Used to hide legacy arguments that have been deprecated.
931
+
932
+ Returns:
933
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
934
+ ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
935
+ Next token prediction loss.
936
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
937
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
938
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
939
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
940
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
941
+
942
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
943
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
944
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
945
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
946
+
947
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
948
+ heads.
949
+
950
+ Example::
951
+
952
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
953
+ >>> import torch
954
+
955
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
956
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
957
+ >>> config.is_decoder = True
958
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
959
+
960
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
961
+ >>> outputs = model(**inputs)
962
+
963
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
964
+ """
965
+
966
+ outputs = self.bert(
967
+ input_ids,
968
+ attention_mask=attention_mask,
969
+ token_type_ids=token_type_ids,
970
+ position_ids=position_ids,
971
+ head_mask=head_mask,
972
+ inputs_embeds=inputs_embeds,
973
+ encoder_hidden_states=encoder_hidden_states,
974
+ encoder_attention_mask=encoder_attention_mask,
975
+ output_attentions=output_attentions,
976
+ output_hidden_states=output_hidden_states,
977
+ )
978
+
979
+ sequence_output = outputs[0]
980
+ prediction_scores = self.cls(sequence_output)
981
+
982
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
983
+
984
+ if labels is not None:
985
+ # we are doing next-token prediction; shift prediction scores and input ids by one
986
+ prediction_scores = prediction_scores[:, :-1, :].contiguous()
987
+ labels = labels[:, 1:].contiguous()
988
+ loss_fct = CrossEntropyLoss()
989
+ ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
990
+ outputs = (ltr_lm_loss,) + outputs
991
+
992
+ return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
993
+
994
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
995
+ input_shape = input_ids.shape
996
+
997
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
998
+ if attention_mask is None:
999
+ attention_mask = input_ids.new_ones(input_shape)
1000
+
1001
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1002
+
1003
+
1004
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
1005
+ class BertForMaskedLM(BertPreTrainedModel):
1006
+ def __init__(self, config):
1007
+ super().__init__(config)
1008
+ assert (
1009
+ not config.is_decoder
1010
+ ), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
1011
+
1012
+ self.bert = BertModel(config)
1013
+ self.cls = BertOnlyMLMHead(config)
1014
+
1015
+ self.init_weights()
1016
+
1017
+ def get_output_embeddings(self):
1018
+ return self.cls.predictions.decoder
1019
+
1020
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1021
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1022
+ def forward(
1023
+ self,
1024
+ input_ids=None,
1025
+ attention_mask=None,
1026
+ token_type_ids=None,
1027
+ position_ids=None,
1028
+ head_mask=None,
1029
+ inputs_embeds=None,
1030
+ labels=None,
1031
+ encoder_hidden_states=None,
1032
+ encoder_attention_mask=None,
1033
+ output_attentions=None,
1034
+ output_hidden_states=None,
1035
+ **kwargs
1036
+ ):
1037
+ r"""
1038
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1039
+ Labels for computing the masked language modeling loss.
1040
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
1041
+ Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
1042
+ in ``[0, ..., config.vocab_size]``
1043
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
1044
+ Used to hide legacy arguments that have been deprecated.
1045
+
1046
+ Returns:
1047
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1048
+ masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
1049
+ Masked language modeling loss.
1050
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
1051
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1052
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1053
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1054
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1055
+
1056
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1057
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1058
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1059
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1060
+
1061
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1062
+ heads.
1063
+ """
1064
+ if "masked_lm_labels" in kwargs:
1065
+ warnings.warn(
1066
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
1067
+ DeprecationWarning,
1068
+ )
1069
+ labels = kwargs.pop("masked_lm_labels")
1070
+ assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
1071
+ assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
1072
+
1073
+ outputs = self.bert(
1074
+ input_ids,
1075
+ attention_mask=attention_mask,
1076
+ token_type_ids=token_type_ids,
1077
+ position_ids=position_ids,
1078
+ head_mask=head_mask,
1079
+ inputs_embeds=inputs_embeds,
1080
+ encoder_hidden_states=encoder_hidden_states,
1081
+ encoder_attention_mask=encoder_attention_mask,
1082
+ output_attentions=output_attentions,
1083
+ output_hidden_states=output_hidden_states,
1084
+ )
1085
+
1086
+ sequence_output = outputs[0]
1087
+ prediction_scores = self.cls(sequence_output)
1088
+
1089
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
1090
+
1091
+ if labels is not None:
1092
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1093
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1094
+ outputs = (masked_lm_loss,) + outputs
1095
+
1096
+ return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
1097
+
1098
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1099
+ input_shape = input_ids.shape
1100
+ effective_batch_size = input_shape[0]
1101
+
1102
+ # add a dummy token
1103
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1104
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1105
+ dummy_token = torch.full(
1106
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1107
+ )
1108
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1109
+
1110
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1111
+
1112
+
1113
+ @add_start_docstrings(
1114
+ """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
1115
+ )
1116
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1117
+ def __init__(self, config):
1118
+ super().__init__(config)
1119
+
1120
+ self.bert = BertModel(config)
1121
+ self.cls = BertOnlyNSPHead(config)
1122
+
1123
+ self.init_weights()
1124
+
1125
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1126
+ def forward(
1127
+ self,
1128
+ input_ids=None,
1129
+ attention_mask=None,
1130
+ token_type_ids=None,
1131
+ position_ids=None,
1132
+ head_mask=None,
1133
+ inputs_embeds=None,
1134
+ next_sentence_label=None,
1135
+ output_attentions=None,
1136
+ output_hidden_states=None,
1137
+ ):
1138
+ r"""
1139
+ next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1140
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring).
1141
+ Indices should be in ``[0, 1]``.
1142
+ ``0`` indicates sequence B is a continuation of sequence A,
1143
+ ``1`` indicates sequence B is a random sequence.
1144
+
1145
+ Returns:
1146
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1147
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
1148
+ Next sequence prediction (classification) loss.
1149
+ seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
1150
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
1151
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1152
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1153
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1154
+
1155
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1156
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1157
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1158
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1159
+
1160
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1161
+ heads.
1162
+
1163
+ Examples::
1164
+
1165
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
1166
+ >>> import torch
1167
+
1168
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1169
+ >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
1170
+
1171
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1172
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1173
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
1174
+
1175
+ >>> loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
1176
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1177
+ """
1178
+
1179
+ outputs = self.bert(
1180
+ input_ids,
1181
+ attention_mask=attention_mask,
1182
+ token_type_ids=token_type_ids,
1183
+ position_ids=position_ids,
1184
+ head_mask=head_mask,
1185
+ inputs_embeds=inputs_embeds,
1186
+ output_attentions=output_attentions,
1187
+ output_hidden_states=output_hidden_states,
1188
+ )
1189
+
1190
+ pooled_output = outputs[1]
1191
+
1192
+ seq_relationship_score = self.cls(pooled_output)
1193
+
1194
+ outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
1195
+ if next_sentence_label is not None:
1196
+ loss_fct = CrossEntropyLoss()
1197
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1198
+ outputs = (next_sentence_loss,) + outputs
1199
+
1200
+ return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
1201
+
1202
+
1203
+ @add_start_docstrings(
1204
+ """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
1205
+ the pooled output) e.g. for GLUE tasks. """,
1206
+ BERT_START_DOCSTRING,
1207
+ )
1208
+ class BertForSequenceClassification(BertPreTrainedModel):
1209
+ def __init__(self, config):
1210
+ super().__init__(config)
1211
+ self.num_labels = config.num_labels
1212
+
1213
+ self.bert = BertModel(config)
1214
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1215
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1216
+
1217
+ self.init_weights()
1218
+
1219
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1220
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1221
+ def forward(
1222
+ self,
1223
+ input_ids=None,
1224
+ attention_mask=None,
1225
+ token_type_ids=None,
1226
+ position_ids=None,
1227
+ head_mask=None,
1228
+ inputs_embeds=None,
1229
+ labels=None,
1230
+ output_attentions=None,
1231
+ output_hidden_states=None,
1232
+ ):
1233
+ r"""
1234
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1235
+ Labels for computing the sequence classification/regression loss.
1236
+ Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
1237
+ If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1238
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1239
+
1240
+ Returns:
1241
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1242
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
1243
+ Classification (or regression if config.num_labels==1) loss.
1244
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
1245
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
1246
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1247
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1248
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1249
+
1250
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1251
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1252
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1253
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1254
+
1255
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1256
+ heads.
1257
+ """
1258
+
1259
+ outputs = self.bert(
1260
+ input_ids,
1261
+ attention_mask=attention_mask,
1262
+ token_type_ids=token_type_ids,
1263
+ position_ids=position_ids,
1264
+ head_mask=head_mask,
1265
+ inputs_embeds=inputs_embeds,
1266
+ output_attentions=output_attentions,
1267
+ output_hidden_states=output_hidden_states,
1268
+ )
1269
+
1270
+ pooled_output = outputs[1]
1271
+
1272
+ pooled_output = self.dropout(pooled_output)
1273
+ logits = self.classifier(pooled_output)
1274
+
1275
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
1276
+
1277
+ if labels is not None:
1278
+ if self.num_labels == 1:
1279
+ # We are doing regression
1280
+ loss_fct = MSELoss()
1281
+ loss = loss_fct(logits.view(-1), labels.view(-1))
1282
+ else:
1283
+ loss_fct = CrossEntropyLoss()
1284
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1285
+ outputs = (loss,) + outputs
1286
+
1287
+ return outputs # (loss), logits, (hidden_states), (attentions)
1288
+
1289
+
1290
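A minimal usage sketch for the sequence-classification head above. The `bert.*` import path is an assumption based on this folder layout, the label value is hypothetical, and the classification layer is freshly initialized when loading plain `bert-base-uncased`, so the printed logits are illustrative rather than meaningful.

import torch
from bert.modeling_bert import BertForSequenceClassification
from bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# num_labels > 1 selects the cross-entropy branch; num_labels == 1 would switch to MSE regression.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
model.eval()

inputs = tokenizer("The movie was surprisingly good.", return_tensors='pt')
labels = torch.tensor([1])  # hypothetical: 1 = positive

with torch.no_grad():
    loss, logits = model(**inputs, labels=labels)[:2]
print(loss.item(), logits.softmax(dim=-1))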
+ @add_start_docstrings(
1291
+ """Bert Model with a multiple choice classification head on top (a linear layer on top of
1292
+ the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
1293
+ BERT_START_DOCSTRING,
1294
+ )
1295
+ class BertForMultipleChoice(BertPreTrainedModel):
1296
+ def __init__(self, config):
1297
+ super().__init__(config)
1298
+
1299
+ self.bert = BertModel(config)
1300
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1301
+ self.classifier = nn.Linear(config.hidden_size, 1)
1302
+
1303
+ self.init_weights()
1304
+
1305
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
1306
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1307
+ def forward(
1308
+ self,
1309
+ input_ids=None,
1310
+ attention_mask=None,
1311
+ token_type_ids=None,
1312
+ position_ids=None,
1313
+ head_mask=None,
1314
+ inputs_embeds=None,
1315
+ labels=None,
1316
+ output_attentions=None,
1317
+ output_hidden_states=None,
1318
+ ):
1319
+ r"""
1320
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1321
+ Labels for computing the multiple choice classification loss.
1322
+ Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
1323
+ of the input tensors. (see `input_ids` above)
1324
+
1325
+ Returns:
1326
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1327
+ loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
1328
+ Classification loss.
1329
+ classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
1330
+ `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
1331
+
1332
+ Classification scores (before SoftMax).
1333
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1334
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1335
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1336
+
1337
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1338
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1339
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1340
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1341
+
1342
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1343
+ heads.
1344
+ """
1345
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1346
+
1347
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1348
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1349
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1350
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1351
+ inputs_embeds = (
1352
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1353
+ if inputs_embeds is not None
1354
+ else None
1355
+ )
1356
+
1357
+ outputs = self.bert(
1358
+ input_ids,
1359
+ attention_mask=attention_mask,
1360
+ token_type_ids=token_type_ids,
1361
+ position_ids=position_ids,
1362
+ head_mask=head_mask,
1363
+ inputs_embeds=inputs_embeds,
1364
+ output_attentions=output_attentions,
1365
+ output_hidden_states=output_hidden_states,
1366
+ )
1367
+
1368
+ pooled_output = outputs[1]
1369
+
1370
+ pooled_output = self.dropout(pooled_output)
1371
+ logits = self.classifier(pooled_output)
1372
+ reshaped_logits = logits.view(-1, num_choices)
1373
+
1374
+ outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
1375
+
1376
+ if labels is not None:
1377
+ loss_fct = CrossEntropyLoss()
1378
+ loss = loss_fct(reshaped_logits, labels)
1379
+ outputs = (loss,) + outputs
1380
+
1381
+ return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
1382
+
1383
+
1384
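A sketch of how the multiple-choice head above expects its inputs: each (prompt, choice) pair is encoded separately and stacked into shape (batch_size, num_choices, sequence_length), which the forward pass flattens before running BERT. The `bert.*` import path and the example sentences are assumptions, and the single-logit classifier is randomly initialized on top of `bert-base-uncased`, so the outputs are illustrative.

import torch
from bert.modeling_bert import BertForMultipleChoice
from bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
model.eval()

prompt = "The chef tasted the soup and"
choices = ["added a pinch of salt.", "folded the laundry."]

# Encode each (prompt, choice) pair, pad to a common length, and stack to
# (batch_size=1, num_choices=2, sequence_length).
encoded = [tokenizer.encode(prompt, choice) for choice in choices]
max_len = max(len(ids) for ids in encoded)
input_ids = torch.tensor(
    [ids + [tokenizer.pad_token_id] * (max_len - len(ids)) for ids in encoded]
).unsqueeze(0)
attention_mask = (input_ids != tokenizer.pad_token_id).long()
labels = torch.tensor([0])  # hypothetical: the first choice is the correct one

with torch.no_grad():
    loss, reshaped_logits = model(input_ids, attention_mask=attention_mask, labels=labels)[:2]
print(loss.item(), reshaped_logits.shape)  # reshaped_logits: (1, num_choices)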
+ @add_start_docstrings(
1385
+ """Bert Model with a token classification head on top (a linear layer on top of
1386
+ the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
1387
+ BERT_START_DOCSTRING,
1388
+ )
1389
+ class BertForTokenClassification(BertPreTrainedModel):
1390
+ def __init__(self, config):
1391
+ super().__init__(config)
1392
+ self.num_labels = config.num_labels
1393
+
1394
+ self.bert = BertModel(config)
1395
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1396
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1397
+
1398
+ self.init_weights()
1399
+
1400
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1401
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1402
+ def forward(
1403
+ self,
1404
+ input_ids=None,
1405
+ attention_mask=None,
1406
+ token_type_ids=None,
1407
+ position_ids=None,
1408
+ head_mask=None,
1409
+ inputs_embeds=None,
1410
+ labels=None,
1411
+ output_attentions=None,
1412
+ output_hidden_states=None,
1413
+ ):
1414
+ r"""
1415
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1416
+ Labels for computing the token classification loss.
1417
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
1418
+
1419
+ Returns:
1420
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1421
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
1422
+ Classification loss.
1423
+ scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
1424
+ Classification scores (before SoftMax).
1425
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1426
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1427
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1428
+
1429
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1430
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1431
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1432
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1433
+
1434
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1435
+ heads.
1436
+ """
1437
+
1438
+ outputs = self.bert(
1439
+ input_ids,
1440
+ attention_mask=attention_mask,
1441
+ token_type_ids=token_type_ids,
1442
+ position_ids=position_ids,
1443
+ head_mask=head_mask,
1444
+ inputs_embeds=inputs_embeds,
1445
+ output_attentions=output_attentions,
1446
+ output_hidden_states=output_hidden_states,
1447
+ )
1448
+
1449
+ sequence_output = outputs[0]
1450
+
1451
+ sequence_output = self.dropout(sequence_output)
1452
+ logits = self.classifier(sequence_output)
1453
+
1454
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
1455
+ if labels is not None:
1456
+ loss_fct = CrossEntropyLoss()
1457
+ # Only keep active parts of the loss
1458
+ if attention_mask is not None:
1459
+ active_loss = attention_mask.view(-1) == 1
1460
+ active_logits = logits.view(-1, self.num_labels)
1461
+ active_labels = torch.where(
1462
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
1463
+ )
1464
+ loss = loss_fct(active_logits, active_labels)
1465
+ else:
1466
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1467
+ outputs = (loss,) + outputs
1468
+
1469
+ return outputs # (loss), scores, (hidden_states), (attentions)
1470
+
1471
+
1472
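A minimal token-classification (NER-style) sketch for the head above, mainly to show the active-loss behaviour in forward(): positions where attention_mask is 0 are mapped to the loss's ignore_index, so padding never contributes to the loss. The `bert.*` import path, the number of labels, and the all-zero label tensor are placeholders for illustration; the classifier is freshly initialized.

import torch
from bert.modeling_bert import BertForTokenClassification
from bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=5)
model.eval()

inputs = tokenizer("Hugging Face is based in New York City", return_tensors='pt')
seq_len = inputs['input_ids'].shape[1]
labels = torch.zeros((1, seq_len), dtype=torch.long)  # one label id per wordpiece (placeholder values)

with torch.no_grad():
    loss, scores = model(**inputs, labels=labels)[:2]
print(loss.item(), scores.shape)  # scores: (1, seq_len, num_labels)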
+ @add_start_docstrings(
1473
+ """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1474
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
1475
+ BERT_START_DOCSTRING,
1476
+ )
1477
+ class BertForQuestionAnswering(BertPreTrainedModel):
1478
+ def __init__(self, config):
1479
+ super().__init__(config)
1480
+ self.num_labels = config.num_labels
1481
+
1482
+ self.bert = BertModel(config)
1483
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1484
+
1485
+ self.init_weights()
1486
+
1487
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1488
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1489
+ def forward(
1490
+ self,
1491
+ input_ids=None,
1492
+ attention_mask=None,
1493
+ token_type_ids=None,
1494
+ position_ids=None,
1495
+ head_mask=None,
1496
+ inputs_embeds=None,
1497
+ start_positions=None,
1498
+ end_positions=None,
1499
+ output_attentions=None,
1500
+ output_hidden_states=None,
1501
+ ):
1502
+ r"""
1503
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1504
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1505
+ Positions are clamped to the length of the sequence (`sequence_length`).
1506
+ Position outside of the sequence are not taken into account for computing the loss.
1507
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1508
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1509
+ Positions are clamped to the length of the sequence (`sequence_length`).
1510
+ Position outside of the sequence are not taken into account for computing the loss.
1511
+
1512
+ Returns:
1513
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1514
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
1515
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
1516
+ start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
1517
+ Span-start scores (before SoftMax).
1518
+ end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
1519
+ Span-end scores (before SoftMax).
1520
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1521
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1522
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1523
+
1524
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1525
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1526
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1527
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1528
+
1529
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1530
+ heads.
1531
+ """
1532
+
1533
+ outputs = self.bert(
1534
+ input_ids,
1535
+ attention_mask=attention_mask,
1536
+ token_type_ids=token_type_ids,
1537
+ position_ids=position_ids,
1538
+ head_mask=head_mask,
1539
+ inputs_embeds=inputs_embeds,
1540
+ output_attentions=output_attentions,
1541
+ output_hidden_states=output_hidden_states,
1542
+ )
1543
+
1544
+ sequence_output = outputs[0]
1545
+
1546
+ logits = self.qa_outputs(sequence_output)
1547
+ start_logits, end_logits = logits.split(1, dim=-1)
1548
+ start_logits = start_logits.squeeze(-1)
1549
+ end_logits = end_logits.squeeze(-1)
1550
+
1551
+ outputs = (start_logits, end_logits,) + outputs[2:]
1552
+ if start_positions is not None and end_positions is not None:
1553
+ # If we are on multi-GPU, splitting adds a dimension; squeeze it away
1554
+ if len(start_positions.size()) > 1:
1555
+ start_positions = start_positions.squeeze(-1)
1556
+ if len(end_positions.size()) > 1:
1557
+ end_positions = end_positions.squeeze(-1)
1558
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1559
+ ignored_index = start_logits.size(1)
1560
+ start_positions.clamp_(0, ignored_index)
1561
+ end_positions.clamp_(0, ignored_index)
1562
+
1563
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1564
+ start_loss = loss_fct(start_logits, start_positions)
1565
+ end_loss = loss_fct(end_logits, end_positions)
1566
+ total_loss = (start_loss + end_loss) / 2
1567
+ outputs = (total_loss,) + outputs
1568
+
1569
+ return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
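To round off the modeling_bert.py heads, a span-extraction sketch for the question-answering model above: the two columns of `qa_outputs` become start and end logits, and a greedy argmax over each gives a candidate answer span. The `bert.*` import path and the question/context strings are assumptions, and `bert-base-uncased` carries no fine-tuned QA head, so the decoded span is illustrative; a SQuAD-fine-tuned checkpoint would be needed for sensible answers.

import torch
from bert.modeling_bert import BertForQuestionAnswering
from bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
model.eval()

question = "Where is the Eiffel Tower?"
context = "The Eiffel Tower is located in Paris, France."
inputs = tokenizer(question, context, return_tensors='pt')

with torch.no_grad():
    start_logits, end_logits = model(**inputs)[:2]

# Greedy span decoding: best start, then best end not before the start.
start = start_logits.argmax(dim=-1).item()
end = max(end_logits.argmax(dim=-1).item(), start)
print(tokenizer.decode(inputs['input_ids'][0, start:end + 1].tolist()))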
LAVT-RIS/bert/modeling_utils.py ADDED
@@ -0,0 +1,1268 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import logging
19
+ import os
20
+ from typing import Callable, Dict, List, Optional, Tuple
21
+
22
+ import torch
23
+ from torch import Tensor, device, dtype, nn
24
+ from torch.nn import CrossEntropyLoss
25
+ from torch.nn import functional as F
26
+
27
+ from .activations import get_activation
28
+ from .configuration_utils import PretrainedConfig
29
+ from .file_utils import (
30
+ DUMMY_INPUTS,
31
+ TF2_WEIGHTS_NAME,
32
+ TF_WEIGHTS_NAME,
33
+ WEIGHTS_NAME,
34
+ cached_path,
35
+ hf_bucket_url,
36
+ is_remote_url,
37
+ )
38
+ from .generation_utils import GenerationMixin
39
+
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ try:
45
+ from torch.nn import Identity
46
+ except ImportError:
47
+ # Older PyTorch compatibility
48
+ class Identity(nn.Module):
49
+ r"""A placeholder identity operator that is argument-insensitive.
50
+ """
51
+
52
+ def __init__(self, *args, **kwargs):
53
+ super().__init__()
54
+
55
+ def forward(self, input):
56
+ return input
57
+
58
+
59
+ def find_pruneable_heads_and_indices(
60
+ heads: List, n_heads: int, head_size: int, already_pruned_heads: set
61
+ ) -> Tuple[set, "torch.LongTensor"]:
62
+ mask = torch.ones(n_heads, head_size)
63
+ heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
64
+ for head in heads:
65
+ # Compute how many pruned heads are before the head and move the index accordingly
66
+ head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
67
+ mask[head] = 0
68
+ mask = mask.view(-1).contiguous().eq(1)
69
+ index: torch.LongTensor = torch.arange(len(mask))[mask].long()
70
+ return heads, index
71
+
72
+
73
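A quick standalone check of the helper defined just above, assuming it can be imported from this vendored module as `bert.modeling_utils`. It shows how the requested head indices are turned into the flat column indices that survive pruning.

from bert.modeling_utils import find_pruneable_heads_and_indices

# A layer with 12 heads of size 64; prune heads 1 and 5, with nothing pruned before.
heads, index = find_pruneable_heads_and_indices(
    heads=[1, 5], n_heads=12, head_size=64, already_pruned_heads=set()
)
print(sorted(heads))   # [1, 5]
print(index.numel())   # 640 = (12 - 2) * 64 flat positions kept after pruning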
+ class ModuleUtilsMixin:
74
+ """
75
+ A few utilities for torch.nn.Modules, to be used as a mixin.
76
+ """
77
+
78
+ def num_parameters(self, only_trainable: bool = False) -> int:
79
+ """
80
+ Get number of (optionally, trainable) parameters in the module.
81
+ """
82
+ params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
83
+ return sum(p.numel() for p in params)
84
+
85
+ @staticmethod
86
+ def _hook_rss_memory_pre_forward(module, *args, **kwargs):
87
+ try:
88
+ import psutil
89
+ except ImportError:
90
+ raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
91
+
92
+ process = psutil.Process(os.getpid())
93
+ mem = process.memory_info()
94
+ module.mem_rss_pre_forward = mem.rss
95
+ return None
96
+
97
+ @staticmethod
98
+ def _hook_rss_memory_post_forward(module, *args, **kwargs):
99
+ try:
100
+ import psutil
101
+ except ImportError:
102
+ raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
103
+
104
+ process = psutil.Process(os.getpid())
105
+ mem = process.memory_info()
106
+ module.mem_rss_post_forward = mem.rss
107
+ mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
108
+ module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
109
+ return None
110
+
111
+ def add_memory_hooks(self):
112
+ """ Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
113
+ Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
114
+ """
115
+ for module in self.modules():
116
+ module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
117
+ module.register_forward_hook(self._hook_rss_memory_post_forward)
118
+ self.reset_memory_hooks_state()
119
+
120
+ def reset_memory_hooks_state(self):
121
+ for module in self.modules():
122
+ module.mem_rss_diff = 0
123
+ module.mem_rss_post_forward = 0
124
+ module.mem_rss_pre_forward = 0
125
+
126
+ @property
127
+ def device(self) -> device:
128
+ """
129
+ Get torch.device from module, assuming that the whole module has one device.
130
+ """
131
+ try:
132
+ return next(self.parameters()).device
133
+ except StopIteration:
134
+ # For nn.DataParallel compatibility in PyTorch 1.5
135
+
136
+ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
137
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
138
+ return tuples
139
+
140
+ gen = self._named_members(get_members_fn=find_tensor_attributes)
141
+ first_tuple = next(gen)
142
+ return first_tuple[1].device
143
+
144
+ @property
145
+ def dtype(self) -> dtype:
146
+ """
147
+ Get torch.dtype from module, assuming that the whole module has one dtype.
148
+ """
149
+ try:
150
+ return next(self.parameters()).dtype
151
+ except StopIteration:
152
+ # For nn.DataParallel compatibility in PyTorch 1.5
153
+
154
+ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
155
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
156
+ return tuples
157
+
158
+ gen = self._named_members(get_members_fn=find_tensor_attributes)
159
+ first_tuple = next(gen)
160
+ return first_tuple[1].dtype
161
+
162
+ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
163
+ """type: torch.Tensor -> torch.Tensor"""
164
+ if encoder_attention_mask.dim() == 3:
165
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
166
+ if encoder_attention_mask.dim() == 2:
167
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
168
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
169
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
170
+ # /transformer/transformer_layers.py#L270
171
+ # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
172
+ # encoder_extended_attention_mask.transpose(-1, -2))
173
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
174
+
175
+ if self.dtype == torch.float16:
176
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
177
+ elif self.dtype == torch.float32:
178
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
179
+ else:
180
+ raise ValueError(
181
+ "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
182
+ self.dtype
183
+ )
184
+ )
185
+
186
+ return encoder_extended_attention_mask
187
+
188
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
189
+ """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
190
+
191
+ Arguments:
192
+ attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
193
+ input_shape: tuple, shape of input_ids
194
+ device: torch.Device, usually self.device
195
+
196
+ Returns:
197
+ torch.Tensor with dtype of attention_mask.dtype
198
+ """
199
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
200
+ # ourselves in which case we just need to make it broadcastable to all heads.
201
+ if attention_mask.dim() == 3:
202
+ extended_attention_mask = attention_mask[:, None, :, :]
203
+ elif attention_mask.dim() == 2:
204
+ # Provided a padding mask of dimensions [batch_size, seq_length]
205
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
206
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
207
+ if self.config.is_decoder:
208
+ batch_size, seq_length = input_shape
209
+ seq_ids = torch.arange(seq_length, device=device)
210
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
211
+ # causal and attention masks must have same type with pytorch version < 1.3
212
+ causal_mask = causal_mask.to(attention_mask.dtype)
213
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
214
+ else:
215
+ extended_attention_mask = attention_mask[:, None, None, :]
216
+ else:
217
+ raise ValueError(
218
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
219
+ input_shape, attention_mask.shape
220
+ )
221
+ )
222
+
223
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
224
+ # masked positions, this operation will create a tensor which is 0.0 for
225
+ # positions we want to attend and -10000.0 for masked positions.
226
+ # Since we are adding it to the raw scores before the softmax, this is
227
+ # effectively the same as removing these entirely.
228
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
229
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
230
+ return extended_attention_mask
231
+
232
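A standalone sketch of what get_extended_attention_mask does in the common encoder case (a 2-D padding mask, no causal masking): the mask is broadcast over heads and query positions and turned into an additive bias of 0 / -10000 that is later added to the raw attention scores. Only plain torch is used; the tensor values are illustrative.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])                 # 1 = attend, 0 = padding
extended = attention_mask[:, None, None, :].to(torch.float32)    # (batch, 1, 1, seq) for broadcasting
extended = (1.0 - extended) * -10000.0                           # additive bias on the attention scores
print(extended)  # 0.0 where attended, -10000.0 where masked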
+ def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
233
+ """
234
+ # Prepare head mask if needed
235
+ # 1.0 in head_mask indicates we keep the head
236
+ attention_probs has shape bsz x n_heads x N x N
237
+ Arguments:
238
+ head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
239
+ num_hidden_layers: int
240
+ Returns:
241
+ Tensor of shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
242
+ or list with [None] for each layer
243
+ """
244
+ if head_mask is not None:
245
+ head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
246
+ if is_attention_chunked is True:
247
+ head_mask = head_mask.unsqueeze(-1)
248
+ else:
249
+ head_mask = [None] * num_hidden_layers
250
+
251
+ return head_mask
252
+
253
+ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
254
+ """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
255
+ if head_mask.dim() == 1:
256
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
257
+ head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
258
+ elif head_mask.dim() == 2:
259
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
260
+ assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
261
+ head_mask = head_mask.to(dtype=self.dtype)  # switch to float if needed + fp16 compatibility
262
+ return head_mask
263
+
264
+
265
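Before moving on to PreTrainedModel, a small sketch of the mixin's conveniences on a concrete model, assuming the vendored `bert` package is importable and `bert-base-uncased` can be downloaded.

from bert.modeling_bert import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
print(model.num_parameters())                      # total parameter count (roughly 110M for bert-base)
print(model.num_parameters(only_trainable=True))   # same here, since no parameters are frozen
print(model.device, model.dtype)                   # resolved from the model's own parameters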
+ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
266
+ r""" Base class for all models.
267
+
268
+ :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
269
+ as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
270
+
271
+ Class attributes (overridden by derived classes):
272
+ - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
273
+ - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
274
+
275
+ - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
276
+ - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
277
+ - ``path``: a path (string) to the TensorFlow checkpoint.
278
+
279
+ - ``base_model_prefix``: a string indicating the attribute associated with the base model in derived classes of the same architecture adding modules on top of the base model.
280
+ """
281
+ config_class = None
282
+ base_model_prefix = ""
283
+
284
+ @property
285
+ def dummy_inputs(self):
286
+ """ Dummy inputs to do a forward pass in the network.
287
+
288
+ Returns:
289
+ torch.Tensor with dummy inputs
290
+ """
291
+ return {"input_ids": torch.tensor(DUMMY_INPUTS)}
292
+
293
+ def __init__(self, config, *inputs, **kwargs):
294
+ super().__init__()
295
+ if not isinstance(config, PretrainedConfig):
296
+ raise ValueError(
297
+ "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
298
+ "To create a model from a pretrained model use "
299
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
300
+ self.__class__.__name__, self.__class__.__name__
301
+ )
302
+ )
303
+ # Save config in model
304
+ self.config = config
305
+
306
+ @property
307
+ def base_model(self):
308
+ return getattr(self, self.base_model_prefix, self)
309
+
310
+ def get_input_embeddings(self):
311
+ """
312
+ Returns the model's input embeddings.
313
+
314
+ Returns:
315
+ :obj:`nn.Module`:
316
+ A torch module mapping vocabulary to hidden states.
317
+ """
318
+ base_model = getattr(self, self.base_model_prefix, self)
319
+ if base_model is not self:
320
+ return base_model.get_input_embeddings()
321
+ else:
322
+ raise NotImplementedError
323
+
324
+ def set_input_embeddings(self, value: nn.Module):
325
+ """
326
+ Set model's input embeddings
327
+
328
+ Args:
329
+ value (:obj:`nn.Module`):
330
+ A module mapping vocabulary to hidden states.
331
+ """
332
+ base_model = getattr(self, self.base_model_prefix, self)
333
+ if base_model is not self:
334
+ base_model.set_input_embeddings(value)
335
+ else:
336
+ raise NotImplementedError
337
+
338
+ def get_output_embeddings(self):
339
+ """
340
+ Returns the model's output embeddings.
341
+
342
+ Returns:
343
+ :obj:`nn.Module`:
344
+ A torch module mapping hidden states to vocabulary.
345
+ """
346
+ return None # Overwrite for models with output embeddings
347
+
348
+ def tie_weights(self):
349
+ """
350
+ Tie the weights between the input embeddings and the output embeddings.
351
+ If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
352
+ the weights instead.
353
+ """
354
+ output_embeddings = self.get_output_embeddings()
355
+ if output_embeddings is not None:
356
+ self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
357
+
358
+ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
359
+ """ Tie or clone module weights depending of whether we are using TorchScript or not
360
+ """
361
+ if self.config.torchscript:
362
+ output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
363
+ else:
364
+ output_embeddings.weight = input_embeddings.weight
365
+
366
+ if getattr(output_embeddings, "bias", None) is not None:
367
+ output_embeddings.bias.data = torch.nn.functional.pad(
368
+ output_embeddings.bias.data,
369
+ (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],),
370
+ "constant",
371
+ 0,
372
+ )
373
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
374
+ output_embeddings.out_features = input_embeddings.num_embeddings
375
+
376
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
377
+ """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
378
+ Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
379
+
380
+ Arguments:
381
+
382
+ new_num_tokens: (`optional`) int:
383
+ New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
384
+ If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
385
+
386
+ Return: ``torch.nn.Embeddings``
387
+ Pointer to the input tokens Embeddings Module of the model
388
+ """
389
+ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
390
+ model_embeds = base_model._resize_token_embeddings(new_num_tokens)
391
+ if new_num_tokens is None:
392
+ return model_embeds
393
+
394
+ # Update base model and current model config
395
+ self.config.vocab_size = new_num_tokens
396
+ base_model.vocab_size = new_num_tokens
397
+
398
+ # Tie weights again if needed
399
+ self.tie_weights()
400
+
401
+ return model_embeds
402
+
403
+ def _resize_token_embeddings(self, new_num_tokens):
404
+ old_embeddings = self.get_input_embeddings()
405
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
406
+ self.set_input_embeddings(new_embeddings)
407
+ return self.get_input_embeddings()
408
+
409
+ def _get_resized_embeddings(
410
+ self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
411
+ ) -> torch.nn.Embedding:
412
+ """ Build a resized Embedding Module from a provided token Embedding Module.
413
+ Increasing the size will add newly initialized vectors at the end
414
+ Reducing the size will remove vectors from the end
415
+
416
+ Args:
417
+ old_embeddings: ``torch.nn.Embedding``
418
+ Old embeddings to be resized.
419
+ new_num_tokens: (`optional`) int
420
+ New number of tokens in the embedding matrix.
421
+ Increasing the size will add newly initialized vectors at the end
422
+ Reducing the size will remove vectors from the end
423
+ If not provided or None: return the provided token Embedding Module.
424
+ Return: ``torch.nn.Embedding``
425
+ Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
426
+ """
427
+ if new_num_tokens is None:
428
+ return old_embeddings
429
+
430
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
431
+ if old_num_tokens == new_num_tokens:
432
+ return old_embeddings
433
+
434
+ # Build new embeddings
435
+ new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
436
+ new_embeddings.to(old_embeddings.weight.device)
437
+
438
+ # initialize all new embeddings (in particular added tokens)
439
+ self._init_weights(new_embeddings)
440
+
441
+ # Copy token embeddings from the previous weights
442
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
443
+ new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
444
+
445
+ return new_embeddings
446
+
447
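A sketch of the intended resize_token_embeddings workflow, assuming the vendored `bert` package is importable; the added special tokens are hypothetical. New embedding rows are initialized by _init_weights and, for models with an LM head, the output embeddings are re-tied afterwards.

from bert.modeling_bert import BertForMaskedLM
from bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

num_added = tokenizer.add_tokens(['[IMG]', '[REGION]'])      # hypothetical extra tokens
embeddings = model.resize_token_embeddings(len(tokenizer))   # returns the resized input Embedding
print(num_added, embeddings.weight.shape)                    # (vocab_size + num_added, hidden_size)
assert model.config.vocab_size == len(tokenizer)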
+ def init_weights(self):
448
+ """ Initialize and prunes weights if needed. """
449
+ # Initialize weights
450
+ self.apply(self._init_weights)
451
+
452
+ # Prune heads if needed
453
+ if self.config.pruned_heads:
454
+ self.prune_heads(self.config.pruned_heads)
455
+
456
+ # Tie weights if needed
457
+ self.tie_weights()
458
+
459
+ def prune_heads(self, heads_to_prune: Dict):
460
+ """ Prunes heads of the base model.
461
+
462
+ Arguments:
463
+
464
+ heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
465
+ E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
466
+ """
467
+ # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
468
+ for layer, heads in heads_to_prune.items():
469
+ union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
470
+ self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
471
+
472
+ self.base_model._prune_heads(heads_to_prune)
473
+
474
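A usage sketch of prune_heads mirroring the docstring's own example, assuming the vendored `bert` package is importable; the attribute path in the final print reflects how BertAttention records the reduced head count and is included for illustration only.

from bert.modeling_bert import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
model.prune_heads({1: [0, 2], 2: [2, 3]})   # prune heads 0, 2 in layer 1 and heads 2, 3 in layer 2
print(model.config.pruned_heads)            # {1: [0, 2], 2: [2, 3]}, persisted into the config
print(model.encoder.layer[1].attention.self.num_attention_heads)  # 12 - 2 = 10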
+ def save_pretrained(self, save_directory):
475
+ """ Save a model and its configuration file to a directory, so that it
476
+ can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
477
+
478
+ Arguments:
479
+ save_directory: directory to which to save.
480
+ """
481
+ if os.path.isfile(save_directory):
482
+ logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
483
+ return
484
+ os.makedirs(save_directory, exist_ok=True)
485
+
486
+ # Only save the model itself if we are using distributed training
487
+ model_to_save = self.module if hasattr(self, "module") else self
488
+
489
+ # Attach architecture to the config
490
+ model_to_save.config.architectures = [model_to_save.__class__.__name__]
491
+
492
+ # If we save using the predefined names, we can load using `from_pretrained`
493
+ output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
494
+
495
+ if getattr(self.config, "xla_device", False):
496
+ import torch_xla.core.xla_model as xm
497
+
498
+ if xm.is_master_ordinal():
499
+ # Save configuration file
500
+ model_to_save.config.save_pretrained(save_directory)
501
+ # xm.save takes care of saving only from master
502
+ xm.save(model_to_save.state_dict(), output_model_file)
503
+ else:
504
+ model_to_save.config.save_pretrained(save_directory)
505
+ torch.save(model_to_save.state_dict(), output_model_file)
506
+
507
+ logger.info("Model weights saved in {}".format(output_model_file))
508
+
509
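A save/reload round trip for the method above, with a hypothetical target directory. save_pretrained writes the weights (pytorch_model.bin) plus config.json, which is exactly what from_pretrained expects when given a local directory.

from bert.modeling_bert import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
model.save_pretrained('./my_bert_checkpoint/')                  # hypothetical directory; created if missing
reloaded = BertModel.from_pretrained('./my_bert_checkpoint/')   # weights and config restored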
+ @classmethod
510
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
511
+ r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
512
+
513
+ The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
514
+ To train the model, you should first set it back in training mode with ``model.train()``
515
+
516
+ The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
517
+ It is up to you to train those weights with a downstream fine-tuning task.
518
+
519
+ The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
520
+
521
+ Parameters:
522
+ pretrained_model_name_or_path: either:
523
+ - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
524
+ - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
525
+ - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
526
+ - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
527
+ - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
528
+
529
+ model_args: (`optional`) Sequence of positional arguments:
530
+ All remaining positional arguments will be passed to the underlying model's ``__init__`` method
531
+
532
+ config: (`optional`) one of:
533
+ - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
534
+ - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
535
+
536
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
537
+ - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
538
+ - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
539
+ - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
540
+
541
+ state_dict: (`optional`) dict:
542
+ an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
543
+ This option can be used if you want to create a model from a pretrained configuration but load your own weights.
544
+ In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
545
+
546
+ cache_dir: (`optional`) string:
547
+ Path to a directory in which a downloaded pre-trained model
548
+ configuration should be cached if the standard cache should not be used.
549
+
550
+ force_download: (`optional`) boolean, default False:
551
+ Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
552
+
553
+ resume_download: (`optional`) boolean, default False:
554
+ Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
555
+
556
+ proxies: (`optional`) dict, default None:
557
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
558
+ The proxies are used on each request.
559
+
560
+ output_loading_info: (`optional`) boolean:
561
+ Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
562
+
563
+ kwargs: (`optional`) Remaining dictionary of keyword arguments:
564
+ Can be used to update the configuration object (after it is loaded) and to initialize the model (e.g. ``output_attention=True``). They behave differently depending on whether a ``config`` is provided or automatically loaded:
565
+
566
+ - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
567
+ - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
568
+
569
+ Examples::
570
+
571
+ # For example purposes. Not runnable.
572
+ model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
573
+ model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
574
+ model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
575
+ assert model.config.output_attention == True
576
+ # Loading from a TF checkpoint file instead of a PyTorch model (slower)
577
+ config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
578
+ model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
579
+
580
+ """
581
+ config = kwargs.pop("config", None)
582
+ state_dict = kwargs.pop("state_dict", None)
583
+ cache_dir = kwargs.pop("cache_dir", None)
584
+ from_tf = kwargs.pop("from_tf", False)
585
+ force_download = kwargs.pop("force_download", False)
586
+ resume_download = kwargs.pop("resume_download", False)
587
+ proxies = kwargs.pop("proxies", None)
588
+ output_loading_info = kwargs.pop("output_loading_info", False)
589
+ local_files_only = kwargs.pop("local_files_only", False)
590
+ use_cdn = kwargs.pop("use_cdn", True)
591
+
592
+ # Load config if we don't provide a configuration
593
+ if not isinstance(config, PretrainedConfig):
594
+ config_path = config if config is not None else pretrained_model_name_or_path
595
+ config, model_kwargs = cls.config_class.from_pretrained(
596
+ config_path,
597
+ *model_args,
598
+ cache_dir=cache_dir,
599
+ return_unused_kwargs=True,
600
+ force_download=force_download,
601
+ resume_download=resume_download,
602
+ proxies=proxies,
603
+ local_files_only=local_files_only,
604
+ **kwargs,
605
+ )
606
+ else:
607
+ model_kwargs = kwargs
608
+
609
+ # Load model
610
+ if pretrained_model_name_or_path is not None:
611
+ if os.path.isdir(pretrained_model_name_or_path):
612
+ if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
613
+ # Load from a TF 1.0 checkpoint
614
+ archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
615
+ elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
616
+ # Load from a TF 2.0 checkpoint
617
+ archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
618
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
619
+ # Load from a PyTorch checkpoint
620
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
621
+ else:
622
+ raise EnvironmentError(
623
+ "Error no file named {} found in directory {} or `from_tf` set to False".format(
624
+ [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
625
+ pretrained_model_name_or_path,
626
+ )
627
+ )
628
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
629
+ archive_file = pretrained_model_name_or_path
630
+ elif os.path.isfile(pretrained_model_name_or_path + ".index"):
631
+ assert (
632
+ from_tf
633
+ ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
634
+ pretrained_model_name_or_path + ".index"
635
+ )
636
+ archive_file = pretrained_model_name_or_path + ".index"
637
+ else:
638
+ archive_file = hf_bucket_url(
639
+ pretrained_model_name_or_path,
640
+ filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
641
+ use_cdn=use_cdn,
642
+ )
643
+
644
+ try:
645
+ # Load from URL or cache if already cached
646
+ resolved_archive_file = cached_path(
647
+ archive_file,
648
+ cache_dir=cache_dir,
649
+ force_download=force_download,
650
+ proxies=proxies,
651
+ resume_download=resume_download,
652
+ local_files_only=local_files_only,
653
+ )
654
+ if resolved_archive_file is None:
655
+ raise EnvironmentError
656
+ except EnvironmentError:
657
+ msg = (
658
+ f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
659
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
660
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
661
+ )
662
+ raise EnvironmentError(msg)
663
+
664
+ if resolved_archive_file == archive_file:
665
+ logger.info("loading weights file {}".format(archive_file))
666
+ else:
667
+ logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
668
+ else:
669
+ resolved_archive_file = None
670
+
671
+ # Instantiate model.
672
+ model = cls(config, *model_args, **model_kwargs)
673
+
674
+ if state_dict is None and not from_tf:
675
+ try:
676
+ state_dict = torch.load(resolved_archive_file, map_location="cpu")
677
+ except Exception:
678
+ raise OSError(
679
+ "Unable to load weights from pytorch checkpoint file. "
680
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
681
+ )
682
+
683
+ missing_keys = []
684
+ unexpected_keys = []
685
+ error_msgs = []
686
+
687
+ if from_tf:
688
+ if resolved_archive_file.endswith(".index"):
689
+ # Load from a TensorFlow 1.X checkpoint - provided by original authors
690
+ model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
691
+ else:
692
+ # Load from our TensorFlow 2.0 checkpoints
693
+ try:
694
+ from transformers import load_tf2_checkpoint_in_pytorch_model
695
+
696
+ model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
697
+ except ImportError:
698
+ logger.error(
699
+ "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
700
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
701
+ )
702
+ raise
703
+ else:
704
+ # Convert old format to new format if needed from a PyTorch state_dict
705
+ old_keys = []
706
+ new_keys = []
707
+ for key in state_dict.keys():
708
+ new_key = None
709
+ if "gamma" in key:
710
+ new_key = key.replace("gamma", "weight")
711
+ if "beta" in key:
712
+ new_key = key.replace("beta", "bias")
713
+ if new_key:
714
+ old_keys.append(key)
715
+ new_keys.append(new_key)
716
+ for old_key, new_key in zip(old_keys, new_keys):
717
+ state_dict[new_key] = state_dict.pop(old_key)
718
+
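The loop above only renames legacy LayerNorm keys ("gamma"/"beta") to the parameter names current PyTorch modules expect ("weight"/"bias"). A minimal, self-contained sketch of that conversion on a toy state dict (the key names below are illustrative, not taken from a real checkpoint):

import torch

state_dict = {
    "encoder.layer.0.LayerNorm.gamma": torch.ones(4),
    "encoder.layer.0.LayerNorm.beta": torch.zeros(4),
    "encoder.layer.0.dense.weight": torch.randn(4, 4),
}

old_keys, new_keys = [], []
for key in state_dict.keys():
    new_key = None
    if "gamma" in key:
        new_key = key.replace("gamma", "weight")
    if "beta" in key:
        new_key = key.replace("beta", "bias")
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)

print(sorted(state_dict.keys()))
# ['encoder.layer.0.LayerNorm.bias', 'encoder.layer.0.LayerNorm.weight', 'encoder.layer.0.dense.weight']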
719
+ # copy state_dict so _load_from_state_dict can modify it
720
+ metadata = getattr(state_dict, "_metadata", None)
721
+ state_dict = state_dict.copy()
722
+ if metadata is not None:
723
+ state_dict._metadata = metadata
724
+
725
+ ##############################################################################################
726
+ # Print out state_dict's contents: keys
727
+ '''
728
+ for key, _ in state_dict.items():
729
+ print(key)
730
+ '''
731
+ ##############################################################################################
732
+
733
+
734
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
735
+ # so we need to apply the function recursively.
736
+ def load(module: nn.Module, prefix=""):
737
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
738
+ module._load_from_state_dict(
739
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
740
+ )
741
+ for name, child in module._modules.items():
742
+ if child is not None:
743
+ load(child, prefix + name + ".")
744
+
745
+ # Make sure we are able to load base models as well as derived models (with heads)
746
+ start_prefix = ""
747
+ model_to_load = model
748
+ has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
749
+ if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
750
+ start_prefix = cls.base_model_prefix + "."
751
+ if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
752
+ model_to_load = getattr(model, cls.base_model_prefix)
753
+
754
+ load(model_to_load, prefix=start_prefix)
755
+
756
+ if model.__class__.__name__ != model_to_load.__class__.__name__:
757
+ base_model_state_dict = model_to_load.state_dict().keys()
758
+ head_model_state_dict_without_base_prefix = [
759
+ key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
760
+ ]
761
+
762
+ missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
763
+
764
+ if len(unexpected_keys) > 0:
765
+ logger.warning(
766
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
767
+ f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
768
+ f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
769
+ f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
770
+ f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
771
+ f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
772
+ )
773
+ else:
774
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
775
+ if len(missing_keys) > 0:
776
+ logger.warning(
777
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
778
+ f"and are newly initialized: {missing_keys}\n"
779
+ f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
780
+ )
781
+ else:
782
+ logger.info(
783
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
784
+ f"If your task is similar to the task the model of the ckeckpoint was trained on, "
785
+ f"you can already use {model.__class__.__name__} for predictions without further training."
786
+ )
787
+ if len(error_msgs) > 0:
788
+ raise RuntimeError(
789
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
790
+ model.__class__.__name__, "\n\t".join(error_msgs)
791
+ )
792
+ )
793
+ model.tie_weights() # make sure token embedding weights are still tied if needed
794
+
795
+ # Set model in evaluation mode to deactivate DropOut modules by default
796
+ model.eval()
797
+
798
+ if output_loading_info:
799
+ loading_info = {
800
+ "missing_keys": missing_keys,
801
+ "unexpected_keys": unexpected_keys,
802
+ "error_msgs": error_msgs,
803
+ }
804
+ return model, loading_info
805
+
806
+ if hasattr(config, "xla_device") and config.xla_device:
807
+ import torch_xla.core.xla_model as xm
808
+
809
+ model = xm.send_cpu_data_to_device(model, xm.xla_device())
810
+ model.to(xm.xla_device())
811
+
812
+ return model
813
+
814
+
815
+ class Conv1D(nn.Module):
816
+ def __init__(self, nf, nx):
817
+ """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
818
+ Basically works like a Linear layer but the weights are transposed
819
+ """
820
+ super().__init__()
821
+ self.nf = nf
822
+ w = torch.empty(nx, nf)
823
+ nn.init.normal_(w, std=0.02)
824
+ self.weight = nn.Parameter(w)
825
+ self.bias = nn.Parameter(torch.zeros(nf))
826
+
827
+ def forward(self, x):
828
+ size_out = x.size()[:-1] + (self.nf,)
829
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
830
+ x = x.view(*size_out)
831
+ return x
832
+
833
+
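Conv1D above is effectively a Linear layer whose weight matrix is stored transposed. A minimal sketch (assuming PyTorch is available) that checks this equivalence without importing the file:

import torch
import torch.nn as nn

nx, nf = 8, 16
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
b = torch.zeros(nf)

x = torch.randn(2, 5, nx)

# Conv1D forward as defined above: out = x @ w + b, reshaped back to (batch, seq, nf)
conv_out = torch.addmm(b, x.view(-1, nx), w).view(2, 5, nf)

# The same mapping as an nn.Linear whose weight is the transpose of w
linear = nn.Linear(nx, nf)
with torch.no_grad():
    linear.weight.copy_(w.t())
    linear.bias.copy_(b)

print(torch.allclose(conv_out, linear(x), atol=1e-6))  # True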
834
+ class PoolerStartLogits(nn.Module):
835
+ """ Compute SQuAD start_logits from sequence hidden states. """
836
+
837
+ def __init__(self, config):
838
+ super().__init__()
839
+ self.dense = nn.Linear(config.hidden_size, 1)
840
+
841
+ def forward(self, hidden_states, p_mask=None):
842
+ """ Args:
843
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
844
+ invalid position mask such as query and special symbols (PAD, SEP, CLS)
845
+ 1.0 means token should be masked.
846
+ """
847
+ x = self.dense(hidden_states).squeeze(-1)
848
+
849
+ if p_mask is not None:
850
+ if next(self.parameters()).dtype == torch.float16:
851
+ x = x * (1 - p_mask) - 65500 * p_mask
852
+ else:
853
+ x = x * (1 - p_mask) - 1e30 * p_mask
854
+
855
+ return x
856
+
857
+
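The (1 - p_mask) trick above drives logits at invalid positions to a very large negative value, so they receive essentially zero probability after a softmax. A minimal sketch of the same masking on toy logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0, 3.0]])
p_mask = torch.tensor([[0.0, 0.0, 1.0, 1.0]])  # last two positions are invalid

masked = logits * (1 - p_mask) - 1e30 * p_mask
probs = F.softmax(masked, dim=-1)
print(probs)  # probability mass only on the first two positions (~0.82 and ~0.18)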
858
+ class PoolerEndLogits(nn.Module):
859
+ """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
860
+ """
861
+
862
+ def __init__(self, config):
863
+ super().__init__()
864
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
865
+ self.activation = nn.Tanh()
866
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
867
+ self.dense_1 = nn.Linear(config.hidden_size, 1)
868
+
869
+ def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
870
+ """ Args:
871
+ One of ``start_states``, ``start_positions`` should be not None.
872
+ If both are set, ``start_positions`` overrides ``start_states``.
873
+
874
+ **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
875
+ hidden states of the first tokens for the labeled span.
876
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
877
+ position of the first token for the labeled span:
878
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
879
+ Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
880
+ 1.0 means token should be masked.
881
+ """
882
+ assert (
883
+ start_states is not None or start_positions is not None
884
+ ), "One of start_states, start_positions should be not None"
885
+ if start_positions is not None:
886
+ slen, hsz = hidden_states.shape[-2:]
887
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
888
+ start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
889
+ start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
890
+
891
+ x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
892
+ x = self.activation(x)
893
+ x = self.LayerNorm(x)
894
+ x = self.dense_1(x).squeeze(-1)
895
+
896
+ if p_mask is not None:
897
+ if next(self.parameters()).dtype == torch.float16:
898
+ x = x * (1 - p_mask) - 65500 * p_mask
899
+ else:
900
+ x = x * (1 - p_mask) - 1e30 * p_mask
901
+
902
+ return x
903
+
904
+
905
+ class PoolerAnswerClass(nn.Module):
906
+ """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
907
+
908
+ def __init__(self, config):
909
+ super().__init__()
910
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
911
+ self.activation = nn.Tanh()
912
+ self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
913
+
914
+ def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
915
+ """
916
+ Args:
917
+ One of ``start_states``, ``start_positions`` should be not None.
918
+ If both are set, ``start_positions`` overrides ``start_states``.
919
+
920
+ **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
921
+ hidden states of the first tokens for the labeled span.
922
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
923
+ position of the first token for the labeled span.
924
+ **cls_index**: torch.LongTensor of shape ``(batch_size,)``
925
+ position of the CLS token. If None, take the last token.
926
+
927
+ note(Original repo):
928
+ no dependency on end_feature so that we can obtain one single `cls_logits`
929
+ for each sample
930
+ """
931
+ hsz = hidden_states.shape[-1]
932
+ assert (
933
+ start_states is not None or start_positions is not None
934
+ ), "One of start_states, start_positions should be not None"
935
+ if start_positions is not None:
936
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
937
+ start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
938
+
939
+ if cls_index is not None:
940
+ cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
941
+ cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
942
+ else:
943
+ cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
944
+
945
+ x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
946
+ x = self.activation(x)
947
+ x = self.dense_1(x).squeeze(-1)
948
+
949
+ return x
950
+
951
+
952
+ class SQuADHead(nn.Module):
953
+ r""" A SQuAD head inspired by XLNet.
954
+
955
+ Parameters:
956
+ config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
957
+
958
+ Inputs:
959
+ **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
960
+ hidden states of sequence tokens
961
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
962
+ position of the first token for the labeled span.
963
+ **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
964
+ position of the last token for the labeled span.
965
+ **cls_index**: torch.LongTensor of shape ``(batch_size,)``
966
+ position of the CLS token. If None, take the last token.
967
+ **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
968
+ Whether the question has a possible answer in the paragraph or not.
969
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
970
+ Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
971
+ 1.0 means token should be masked.
972
+
973
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
974
+ **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
975
+ Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
976
+ **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
977
+ ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
978
+ Log probabilities for the top config.start_n_top start token possibilities (beam-search).
979
+ **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
980
+ ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
981
+ Indices for the top config.start_n_top start token possibilities (beam-search).
982
+ **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
983
+ ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
984
+ Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
985
+ **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
986
+ ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
987
+ Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
988
+ **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
989
+ ``torch.FloatTensor`` of shape ``(batch_size,)``
990
+ Log probabilities for the ``is_impossible`` label of the answers.
991
+ """
992
+
993
+ def __init__(self, config):
994
+ super().__init__()
995
+ self.start_n_top = config.start_n_top
996
+ self.end_n_top = config.end_n_top
997
+
998
+ self.start_logits = PoolerStartLogits(config)
999
+ self.end_logits = PoolerEndLogits(config)
1000
+ self.answer_class = PoolerAnswerClass(config)
1001
+
1002
+ def forward(
1003
+ self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
1004
+ ):
1005
+ outputs = ()
1006
+
1007
+ start_logits = self.start_logits(hidden_states, p_mask=p_mask)
1008
+
1009
+ if start_positions is not None and end_positions is not None:
1010
+ # If we are on multi-GPU, let's remove the dimension added by batch splitting
1011
+ for x in (start_positions, end_positions, cls_index, is_impossible):
1012
+ if x is not None and x.dim() > 1:
1013
+ x.squeeze_(-1)
1014
+
1015
+ # during training, compute the end logits based on the ground truth of the start position
1016
+ end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
1017
+
1018
+ loss_fct = CrossEntropyLoss()
1019
+ start_loss = loss_fct(start_logits, start_positions)
1020
+ end_loss = loss_fct(end_logits, end_positions)
1021
+ total_loss = (start_loss + end_loss) / 2
1022
+
1023
+ if cls_index is not None and is_impossible is not None:
1024
+ # Predict answerability from the representation of CLS and START
1025
+ cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
1026
+ loss_fct_cls = nn.BCEWithLogitsLoss()
1027
+ cls_loss = loss_fct_cls(cls_logits, is_impossible)
1028
+
1029
+ # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
1030
+ total_loss += cls_loss * 0.5
1031
+
1032
+ outputs = (total_loss,) + outputs
1033
+
1034
+ else:
1035
+ # during inference, compute the end logits based on beam search
1036
+ bsz, slen, hsz = hidden_states.size()
1037
+ start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
1038
+
1039
+ start_top_log_probs, start_top_index = torch.topk(
1040
+ start_log_probs, self.start_n_top, dim=-1
1041
+ ) # shape (bsz, start_n_top)
1042
+ start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
1043
+ start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
1044
+ start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
1045
+
1046
+ hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
1047
+ start_states
1048
+ ) # shape (bsz, slen, start_n_top, hsz)
1049
+ p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
1050
+ end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
1051
+ end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
1052
+
1053
+ end_top_log_probs, end_top_index = torch.topk(
1054
+ end_log_probs, self.end_n_top, dim=1
1055
+ ) # shape (bsz, end_n_top, start_n_top)
1056
+ end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
1057
+ end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
1058
+
1059
+ start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
1060
+ cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
1061
+
1062
+ outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs
1063
+
1064
+ # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
1065
+ # or (if labels are provided) (total_loss,)
1066
+ return outputs
1067
+
1068
+
1069
+ class SequenceSummary(nn.Module):
1070
+ r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
1071
+ Args of the config class:
1072
+ summary_type:
1073
+ - 'last' => [default] take the last token hidden state (like XLNet)
1074
+ - 'first' => take the first token hidden state (like Bert)
1075
+ - 'mean' => take the mean of all tokens hidden states
1076
+ - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
1077
+ - 'attn' => Not implemented now, use multi-head attention
1078
+ summary_use_proj: Add a projection after the vector extraction
1079
+ summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
1080
+ summary_activation: 'tanh' or another activation name => add that activation to the output, otherwise no activation. Default: no activation.
1081
+ summary_first_dropout: Add a dropout before the projection and activation
1082
+ summary_last_dropout: Add a dropout after the projection and activation
1083
+ """
1084
+
1085
+ def __init__(self, config: PretrainedConfig):
1086
+ super().__init__()
1087
+
1088
+ self.summary_type = getattr(config, "summary_type", "last")
1089
+ if self.summary_type == "attn":
1090
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
1091
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
1092
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
1093
+ raise NotImplementedError
1094
+
1095
+ self.summary = Identity()
1096
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
1097
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
1098
+ num_classes = config.num_labels
1099
+ else:
1100
+ num_classes = config.hidden_size
1101
+ self.summary = nn.Linear(config.hidden_size, num_classes)
1102
+
1103
+ activation_string = getattr(config, "summary_activation", None)
1104
+ self.activation: Callable = (get_activation(activation_string) if activation_string else Identity())
1105
+
1106
+ self.first_dropout = Identity()
1107
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
1108
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
1109
+
1110
+ self.last_dropout = Identity()
1111
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
1112
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
1113
+
1114
+ def forward(self, hidden_states, cls_index=None):
1115
+ """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
1116
+ cls_index: [optional] position of the classification token if summary_type == 'cls_index',
1117
+ shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
1118
+ if summary_type == 'cls_index' and cls_index is None:
1119
+ we take the last token of the sequence as classification token
1120
+ """
1121
+ if self.summary_type == "last":
1122
+ output = hidden_states[:, -1]
1123
+ elif self.summary_type == "first":
1124
+ output = hidden_states[:, 0]
1125
+ elif self.summary_type == "mean":
1126
+ output = hidden_states.mean(dim=1)
1127
+ elif self.summary_type == "cls_index":
1128
+ if cls_index is None:
1129
+ cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)
1130
+ else:
1131
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
1132
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
1133
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
1134
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
1135
+ elif self.summary_type == "attn":
1136
+ raise NotImplementedError
1137
+
1138
+ output = self.first_dropout(output)
1139
+ output = self.summary(output)
1140
+ output = self.activation(output)
1141
+ output = self.last_dropout(output)
1142
+
1143
+ return output
1144
+
1145
+
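A minimal sketch of the pooling choices behind summary_type ('last', 'first', 'mean', 'cls_index') on a toy hidden-state tensor. It reproduces only the selection logic and does not instantiate SequenceSummary itself, which requires a full config object:

import torch

hidden_states = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)  # (batch, seq_len, hidden)

last = hidden_states[:, -1]        # summary_type == "last"  -> shape (2, 4)
first = hidden_states[:, 0]        # summary_type == "first" -> shape (2, 4)
mean = hidden_states.mean(dim=1)   # summary_type == "mean"  -> shape (2, 4)

# "cls_index": gather the hidden state at an explicit token position per sample
cls_index = torch.tensor([2, 0])                      # one position per batch element
idx = cls_index[:, None, None].expand(-1, 1, 4)       # shape (2, 1, 4)
cls_states = hidden_states.gather(1, idx).squeeze(1)  # shape (2, 4)

print(last.shape, first.shape, mean.shape, cls_states.shape)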
1146
+ def prune_linear_layer(layer, index, dim=0):
1147
+ """ Prune a linear layer (a model parameters) to keep only entries in index.
1148
+ Return the pruned layer as a new layer with requires_grad=True.
1149
+ Used to remove heads.
1150
+ """
1151
+ index = index.to(layer.weight.device)
1152
+ W = layer.weight.index_select(dim, index).clone().detach()
1153
+ if layer.bias is not None:
1154
+ if dim == 1:
1155
+ b = layer.bias.clone().detach()
1156
+ else:
1157
+ b = layer.bias[index].clone().detach()
1158
+ new_size = list(layer.weight.size())
1159
+ new_size[dim] = len(index)
1160
+ new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
1161
+ new_layer.weight.requires_grad = False
1162
+ new_layer.weight.copy_(W.contiguous())
1163
+ new_layer.weight.requires_grad = True
1164
+ if layer.bias is not None:
1165
+ new_layer.bias.requires_grad = False
1166
+ new_layer.bias.copy_(b.contiguous())
1167
+ new_layer.bias.requires_grad = True
1168
+ return new_layer
1169
+
1170
+
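The core idea of prune_linear_layer is to index_select the rows of the weight (and the matching bias entries) so only the kept output units remain. A minimal sketch under that reading, checking the pruned layer against slicing the original output:

import torch
import torch.nn as nn

layer = nn.Linear(6, 4)
index = torch.tensor([0, 2])  # keep output units 0 and 2

W = layer.weight.index_select(0, index).clone().detach()
b = layer.bias[index].clone().detach()

new_layer = nn.Linear(6, 2)
with torch.no_grad():
    new_layer.weight.copy_(W)
    new_layer.bias.copy_(b)

x = torch.randn(3, 6)
print(torch.allclose(new_layer(x), layer(x)[:, index]))  # True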
1171
+ def prune_conv1d_layer(layer, index, dim=1):
1172
+ """ Prune a Conv1D layer (a model parameters) to keep only entries in index.
1173
+ A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
1174
+ Return the pruned layer as a new layer with requires_grad=True.
1175
+ Used to remove heads.
1176
+ """
1177
+ index = index.to(layer.weight.device)
1178
+ W = layer.weight.index_select(dim, index).clone().detach()
1179
+ if dim == 0:
1180
+ b = layer.bias.clone().detach()
1181
+ else:
1182
+ b = layer.bias[index].clone().detach()
1183
+ new_size = list(layer.weight.size())
1184
+ new_size[dim] = len(index)
1185
+ new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
1186
+ new_layer.weight.requires_grad = False
1187
+ new_layer.weight.copy_(W.contiguous())
1188
+ new_layer.weight.requires_grad = True
1189
+ new_layer.bias.requires_grad = False
1190
+ new_layer.bias.copy_(b.contiguous())
1191
+ new_layer.bias.requires_grad = True
1192
+ return new_layer
1193
+
1194
+
1195
+ def prune_layer(layer, index, dim=None):
1196
+ """ Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
1197
+ Return the pruned layer as a new layer with requires_grad=True.
1198
+ Used to remove heads.
1199
+ """
1200
+ if isinstance(layer, nn.Linear):
1201
+ return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
1202
+ elif isinstance(layer, Conv1D):
1203
+ return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
1204
+ else:
1205
+ raise ValueError("Can't prune layer of class {}".format(layer.__class__))
1206
+
1207
+
1208
+ def apply_chunking_to_forward(
1209
+ chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
1210
+ ) -> torch.Tensor:
1211
+ """
1212
+ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`.
1213
+ It then applies a layer `forward_fn` to each chunk independently to save memory.
1214
+ If the `forward_fn` is independent across the `chunk_dim` this function will yield the
1215
+ same result as not applying it.
1216
+
1217
+ Args:
1218
+ chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size`
1219
+ chunk_dim: int - the dimension over which the input_tensors should be chunked
1220
+ forward_fn: fn - the forward fn of the model
1221
+ input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked
1222
+ Returns:
1223
+ a Tensor with the same shape the forward_fn would have given if applied
1224
+
1225
+
1226
+ Examples::
1227
+
1228
+ # rename the usual forward() fn to forward_chunk()
1229
+ def forward_chunk(self, hidden_states):
1230
+ hidden_states = self.decoder(hidden_states)
1231
+ return hidden_states
1232
+
1233
+ # implement a chunked forward function
1234
+ def forward(self, hidden_states):
1235
+ return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
1236
+ """
1237
+
1238
+ assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
1239
+ tensor_shape = input_tensors[0].shape
1240
+ assert all(
1241
+ input_tensor.shape == tensor_shape for input_tensor in input_tensors
1242
+ ), "All input tenors have to be of the same shape"
1243
+
1244
+ # inspect.signature exists since Python 3.5 and is a Python method -> no problem with backward compatibility
1245
+ num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
1246
+ assert num_args_in_forward_chunk_fn == len(
1247
+ input_tensors
1248
+ ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
1249
+ num_args_in_forward_chunk_fn, len(input_tensors)
1250
+ )
1251
+
1252
+ if chunk_size > 0:
1253
+ assert (
1254
+ input_tensors[0].shape[chunk_dim] % chunk_size == 0
1255
+ ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
1256
+ input_tensors[0].shape[chunk_dim], chunk_size
1257
+ )
1258
+
1259
+ num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
1260
+
1261
+ # chunk input tensor into tuples
1262
+ input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
1263
+ # apply forward fn to every tuple
1264
+ output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
1265
+ # concatenate output at same dimension
1266
+ return torch.cat(output_chunks, dim=chunk_dim)
1267
+
1268
+ return forward_fn(*input_tensors)
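apply_chunking_to_forward relies on the forward function being independent across the chunked dimension, so applying it chunk by chunk and concatenating the results matches a single full-tensor call. A minimal sketch of that property with a position-wise Linear (the names below are illustrative):

import torch
import torch.nn as nn

decoder = nn.Linear(8, 8)

def forward_chunk(hidden_states):
    # position-wise: independent across the sequence dimension
    return decoder(hidden_states)

hidden = torch.randn(2, 6, 8)  # (batch, seq_len, hidden)

chunk_size, chunk_dim = 2, 1
num_chunks = hidden.shape[chunk_dim] // chunk_size
chunks = hidden.chunk(num_chunks, dim=chunk_dim)
chunked_out = torch.cat([forward_chunk(c) for c in chunks], dim=chunk_dim)

full_out = forward_chunk(hidden)
print(torch.allclose(chunked_out, full_out, atol=1e-6))  # True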
LAVT-RIS/bert/tokenization_bert.py ADDED
@@ -0,0 +1,545 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes."""
16
+
17
+
18
+ import collections
19
+ import logging
20
+ import os
21
+ import unicodedata
22
+ from typing import List, Optional
23
+
24
+ from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
25
+
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {
32
+ "vocab_file": {
33
+ "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
34
+ "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
35
+ "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
36
+ "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
37
+ "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
38
+ "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
39
+ "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
40
+ "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
41
+ "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
42
+ "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
43
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
44
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
45
+ "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
46
+ "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
47
+ "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
48
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
49
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
50
+ "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
51
+ }
52
+ }
53
+
54
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
55
+ "bert-base-uncased": 512,
56
+ "bert-large-uncased": 512,
57
+ "bert-base-cased": 512,
58
+ "bert-large-cased": 512,
59
+ "bert-base-multilingual-uncased": 512,
60
+ "bert-base-multilingual-cased": 512,
61
+ "bert-base-chinese": 512,
62
+ "bert-base-german-cased": 512,
63
+ "bert-large-uncased-whole-word-masking": 512,
64
+ "bert-large-cased-whole-word-masking": 512,
65
+ "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
66
+ "bert-large-cased-whole-word-masking-finetuned-squad": 512,
67
+ "bert-base-cased-finetuned-mrpc": 512,
68
+ "bert-base-german-dbmdz-cased": 512,
69
+ "bert-base-german-dbmdz-uncased": 512,
70
+ "TurkuNLP/bert-base-finnish-cased-v1": 512,
71
+ "TurkuNLP/bert-base-finnish-uncased-v1": 512,
72
+ "wietsedv/bert-base-dutch-cased": 512,
73
+ }
74
+
75
+ PRETRAINED_INIT_CONFIGURATION = {
76
+ "bert-base-uncased": {"do_lower_case": True},
77
+ "bert-large-uncased": {"do_lower_case": True},
78
+ "bert-base-cased": {"do_lower_case": False},
79
+ "bert-large-cased": {"do_lower_case": False},
80
+ "bert-base-multilingual-uncased": {"do_lower_case": True},
81
+ "bert-base-multilingual-cased": {"do_lower_case": False},
82
+ "bert-base-chinese": {"do_lower_case": False},
83
+ "bert-base-german-cased": {"do_lower_case": False},
84
+ "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
85
+ "bert-large-cased-whole-word-masking": {"do_lower_case": False},
86
+ "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
87
+ "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
88
+ "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
89
+ "bert-base-german-dbmdz-cased": {"do_lower_case": False},
90
+ "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
91
+ "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
92
+ "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
93
+ "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
94
+ }
95
+
96
+
97
+ def load_vocab(vocab_file):
98
+ """Loads a vocabulary file into a dictionary."""
99
+ vocab = collections.OrderedDict()
100
+ with open(vocab_file, "r", encoding="utf-8") as reader:
101
+ tokens = reader.readlines()
102
+ for index, token in enumerate(tokens):
103
+ token = token.rstrip("\n")
104
+ vocab[token] = index
105
+ return vocab
106
+
107
+
108
+ def whitespace_tokenize(text):
109
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
110
+ text = text.strip()
111
+ if not text:
112
+ return []
113
+ tokens = text.split()
114
+ return tokens
115
+
116
+
117
+ class BertTokenizer(PreTrainedTokenizer):
118
+ r"""
119
+ Constructs a BERT tokenizer. Based on WordPiece.
120
+
121
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
122
+ should refer to the superclass for more information regarding methods.
123
+
124
+ Args:
125
+ vocab_file (:obj:`string`):
126
+ File containing the vocabulary.
127
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
128
+ Whether to lowercase the input when tokenizing.
129
+ do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
130
+ Whether to do basic tokenization before WordPiece.
131
+ never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`):
132
+ Collection of tokens which will never be split during tokenization. Only has an effect when
133
+ :obj:`do_basic_tokenize=True`
134
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
135
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
136
+ token instead.
137
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
138
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
139
+ for sequence classification or for a text and a question for question answering.
140
+ It is also used as the last token of a sequence built with special tokens.
141
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
142
+ The token used for padding, for example when batching sequences of different lengths.
143
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
144
+ The classifier token which is used when doing sequence classification (classification of the whole
145
+ sequence instead of per-token classification). It is the first token of the sequence when built with
146
+ special tokens.
147
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
148
+ The token used for masking values. This is the token used when training this model with masked language
149
+ modeling. This is the token which the model will try to predict.
150
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
151
+ Whether to tokenize Chinese characters.
152
+ This should likely be deactivated for Japanese:
153
+ see: https://github.com/huggingface/transformers/issues/328
154
+ """
155
+
156
+ vocab_files_names = VOCAB_FILES_NAMES
157
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
158
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
159
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
160
+
161
+ def __init__(
162
+ self,
163
+ vocab_file,
164
+ do_lower_case=True,
165
+ do_basic_tokenize=True,
166
+ never_split=None,
167
+ unk_token="[UNK]",
168
+ sep_token="[SEP]",
169
+ pad_token="[PAD]",
170
+ cls_token="[CLS]",
171
+ mask_token="[MASK]",
172
+ tokenize_chinese_chars=True,
173
+ **kwargs
174
+ ):
175
+ super().__init__(
176
+ unk_token=unk_token,
177
+ sep_token=sep_token,
178
+ pad_token=pad_token,
179
+ cls_token=cls_token,
180
+ mask_token=mask_token,
181
+ **kwargs,
182
+ )
183
+
184
+ if not os.path.isfile(vocab_file):
185
+ raise ValueError(
186
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
187
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
188
+ )
189
+ self.vocab = load_vocab(vocab_file)
190
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
191
+ self.do_basic_tokenize = do_basic_tokenize
192
+ if do_basic_tokenize:
193
+ self.basic_tokenizer = BasicTokenizer(
194
+ do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
195
+ )
196
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
197
+
198
+ @property
199
+ def vocab_size(self):
200
+ return len(self.vocab)
201
+
202
+ def get_vocab(self):
203
+ return dict(self.vocab, **self.added_tokens_encoder)
204
+
205
+ def _tokenize(self, text):
206
+ split_tokens = []
207
+ if self.do_basic_tokenize:
208
+ for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
209
+
210
+ # If the token is part of the never_split set
211
+ if token in self.basic_tokenizer.never_split:
212
+ split_tokens.append(token)
213
+ else:
214
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
215
+ else:
216
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
217
+ return split_tokens
218
+
219
+ def _convert_token_to_id(self, token):
220
+ """ Converts a token (str) in an id using the vocab. """
221
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
222
+
223
+ def _convert_id_to_token(self, index):
224
+ """Converts an index (integer) in a token (str) using the vocab."""
225
+ return self.ids_to_tokens.get(index, self.unk_token)
226
+
227
+ def convert_tokens_to_string(self, tokens):
228
+ """ Converts a sequence of tokens (string) in a single string. """
229
+ out_string = " ".join(tokens).replace(" ##", "").strip()
230
+ return out_string
231
+
232
+ def build_inputs_with_special_tokens(
233
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
234
+ ) -> List[int]:
235
+ """
236
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
237
+ by concatenating and adding special tokens.
238
+ A BERT sequence has the following format:
239
+
240
+ - single sequence: ``[CLS] X [SEP]``
241
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
242
+
243
+ Args:
244
+ token_ids_0 (:obj:`List[int]`):
245
+ List of IDs to which the special tokens will be added
246
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
247
+ Optional second list of IDs for sequence pairs.
248
+
249
+ Returns:
250
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
251
+ """
252
+ if token_ids_1 is None:
253
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
254
+ cls = [self.cls_token_id]
255
+ sep = [self.sep_token_id]
256
+ return cls + token_ids_0 + sep + token_ids_1 + sep
257
+
258
+ def get_special_tokens_mask(
259
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
260
+ ) -> List[int]:
261
+ """
262
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
263
+ special tokens using the tokenizer ``prepare_for_model`` method.
264
+
265
+ Args:
266
+ token_ids_0 (:obj:`List[int]`):
267
+ List of ids.
268
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
269
+ Optional second list of IDs for sequence pairs.
270
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
271
+ Set to True if the token list is already formatted with special tokens for the model
272
+
273
+ Returns:
274
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
275
+ """
276
+
277
+ if already_has_special_tokens:
278
+ if token_ids_1 is not None:
279
+ raise ValueError(
280
+ "You should not supply a second sequence if the provided sequence of "
281
+ "ids is already formated with special tokens for the model."
282
+ )
283
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
284
+
285
+ if token_ids_1 is not None:
286
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
287
+ return [1] + ([0] * len(token_ids_0)) + [1]
288
+
289
+ def create_token_type_ids_from_sequences(
290
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
291
+ ) -> List[int]:
292
+ """
293
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
294
+ A BERT sequence pair mask has the following format:
295
+
296
+ ::
297
+
298
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
299
+ | first sequence | second sequence |
300
+
301
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
302
+
303
+ Args:
304
+ token_ids_0 (:obj:`List[int]`):
305
+ List of ids.
306
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
307
+ Optional second list of IDs for sequence pairs.
308
+
309
+ Returns:
310
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
311
+ sequence(s).
312
+ """
313
+ sep = [self.sep_token_id]
314
+ cls = [self.cls_token_id]
315
+ if token_ids_1 is None:
316
+ return len(cls + token_ids_0 + sep) * [0]
317
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
318
+
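A minimal sketch of the sequence-pair layout described above: [CLS] A [SEP] B [SEP], with token type ids of 0 for the first segment (including [CLS] and its [SEP]) and 1 for the second. The integer ids are illustrative placeholders rather than entries from a real vocabulary file:

cls_id, sep_id = 101, 102
token_ids_0 = [7592, 2088]          # segment A (illustrative ids)
token_ids_1 = [2129, 2024, 2017]    # segment B (illustrative ids)

input_ids = [cls_id] + token_ids_0 + [sep_id] + token_ids_1 + [sep_id]
token_type_ids = [0] * (len(token_ids_0) + 2) + [1] * (len(token_ids_1) + 1)

print(input_ids)       # [101, 7592, 2088, 102, 2129, 2024, 2017, 102]
print(token_type_ids)  # [0, 0, 0, 0, 1, 1, 1, 1]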
319
+ def save_vocabulary(self, vocab_path):
320
+ """
321
+ Save the tokenizer's WordPiece vocabulary (one token per line) and special tokens file to a directory.
322
+
323
+ Args:
324
+ vocab_path (:obj:`str`):
325
+ The directory in which to save the vocabulary.
326
+
327
+ Returns:
328
+ :obj:`Tuple(str)`: Paths to the files saved.
329
+ """
330
+ index = 0
331
+ if os.path.isdir(vocab_path):
332
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
333
+ else:
334
+ vocab_file = vocab_path
335
+ with open(vocab_file, "w", encoding="utf-8") as writer:
336
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
337
+ if index != token_index:
338
+ logger.warning(
339
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
340
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)
341
+ )
342
+ index = token_index
343
+ writer.write(token + "\n")
344
+ index += 1
345
+ return (vocab_file,)
346
+
347
+
348
+ class BasicTokenizer(object):
349
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
350
+
351
+ def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
352
+ """ Constructs a BasicTokenizer.
353
+
354
+ Args:
355
+ **do_lower_case**: Whether to lower case the input.
356
+ **never_split**: (`optional`) list of str
357
+ Kept for backward compatibility purposes.
358
+ Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
359
+ List of tokens not to split.
360
+ **tokenize_chinese_chars**: (`optional`) boolean (default True)
361
+ Whether to tokenize Chinese characters.
362
+ This should likely be deactivated for Japanese:
363
+ see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
364
+ """
365
+ if never_split is None:
366
+ never_split = []
367
+ self.do_lower_case = do_lower_case
368
+ self.never_split = set(never_split)
369
+ self.tokenize_chinese_chars = tokenize_chinese_chars
370
+
371
+ def tokenize(self, text, never_split=None):
372
+ """ Basic Tokenization of a piece of text.
373
+ Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.
374
+
375
+ Args:
376
+ **never_split**: (`optional`) list of str
377
+ Kept for backward compatibility purposes.
378
+ Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
379
+ List of tokens not to split.
380
+ """
381
+ # union() returns a new set by concatenating the two sets.
382
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
383
+
384
+ # This was added on November 1st, 2018 for the multilingual and Chinese
385
+ # models. This is also applied to the English models now, but it doesn't
386
+ # matter since the English models were not trained on any Chinese data
387
+ # and generally don't have any Chinese data in them (there are Chinese
388
+ # characters in the vocabulary because Wikipedia does have some Chinese
389
+ # words in the English Wikipedia.).
390
+ if self.tokenize_chinese_chars:
391
+ text = self._tokenize_chinese_chars(text)
392
+ orig_tokens = whitespace_tokenize(text)
393
+ split_tokens = []
394
+ for token in orig_tokens:
395
+ if self.do_lower_case and token not in never_split:
396
+ token = token.lower()
397
+ token = self._run_strip_accents(token)
398
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
399
+
400
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
401
+ return output_tokens
402
+
403
+ def _run_strip_accents(self, text):
404
+ """Strips accents from a piece of text."""
405
+ text = unicodedata.normalize("NFD", text)
406
+ output = []
407
+ for char in text:
408
+ cat = unicodedata.category(char)
409
+ if cat == "Mn":
410
+ continue
411
+ output.append(char)
412
+ return "".join(output)
413
+
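_run_strip_accents works by NFD-decomposing the text and then dropping combining marks (Unicode category "Mn"). A minimal standalone sketch of the same idea:

import unicodedata

def strip_accents(text):
    # Decompose accented characters, then drop the combining marks.
    text = unicodedata.normalize("NFD", text)
    return "".join(ch for ch in text if unicodedata.category(ch) != "Mn")

print(strip_accents("café naïve Müller"))  # cafe naive Muller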
414
+ def _run_split_on_punc(self, text, never_split=None):
415
+ """Splits punctuation on a piece of text."""
416
+ if never_split is not None and text in never_split:
417
+ return [text]
418
+ chars = list(text)
419
+ i = 0
420
+ start_new_word = True
421
+ output = []
422
+ while i < len(chars):
423
+ char = chars[i]
424
+ if _is_punctuation(char):
425
+ output.append([char])
426
+ start_new_word = True
427
+ else:
428
+ if start_new_word:
429
+ output.append([])
430
+ start_new_word = False
431
+ output[-1].append(char)
432
+ i += 1
433
+
434
+ return ["".join(x) for x in output]
435
+
436
+ def _tokenize_chinese_chars(self, text):
437
+ """Adds whitespace around any CJK character."""
438
+ output = []
439
+ for char in text:
440
+ cp = ord(char)
441
+ if self._is_chinese_char(cp):
442
+ output.append(" ")
443
+ output.append(char)
444
+ output.append(" ")
445
+ else:
446
+ output.append(char)
447
+ return "".join(output)
448
+
449
+ def _is_chinese_char(self, cp):
450
+ """Checks whether CP is the codepoint of a CJK character."""
451
+ # This defines a "chinese character" as anything in the CJK Unicode block:
452
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
453
+ #
454
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
455
+ # despite its name. The modern Korean Hangul alphabet is a different block,
456
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
457
+ # space-separated words, so they are not treated specially and handled
458
+ like all of the other languages.
459
+ if (
460
+ (cp >= 0x4E00 and cp <= 0x9FFF)
461
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
462
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
463
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
464
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
465
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
466
+ or (cp >= 0xF900 and cp <= 0xFAFF)
467
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
468
+ ): #
469
+ return True
470
+
471
+ return False
472
+
473
+ def _clean_text(self, text):
474
+ """Performs invalid character removal and whitespace cleanup on text."""
475
+ output = []
476
+ for char in text:
477
+ cp = ord(char)
478
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
479
+ continue
480
+ if _is_whitespace(char):
481
+ output.append(" ")
482
+ else:
483
+ output.append(char)
484
+ return "".join(output)
485
+
486
+
487
+ class WordpieceTokenizer(object):
488
+ """Runs WordPiece tokenization."""
489
+
490
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
491
+ self.vocab = vocab
492
+ self.unk_token = unk_token
493
+ self.max_input_chars_per_word = max_input_chars_per_word
494
+
495
+ def tokenize(self, text):
496
+ """Tokenizes a piece of text into its word pieces.
497
+
498
+ This uses a greedy longest-match-first algorithm to perform tokenization
499
+ using the given vocabulary.
500
+
501
+ For example:
502
+ input = "unaffable"
503
+ output = ["un", "##aff", "##able"]
504
+
505
+ Args:
506
+ text: A single token or whitespace separated tokens. This should have
507
+ already been passed through `BasicTokenizer`.
508
+
509
+ Returns:
510
+ A list of wordpiece tokens.
511
+ """
512
+
513
+ output_tokens = []
514
+ for token in whitespace_tokenize(text):
515
+ chars = list(token)
516
+ if len(chars) > self.max_input_chars_per_word:
517
+ output_tokens.append(self.unk_token)
518
+ continue
519
+
520
+ is_bad = False
521
+ start = 0
522
+ sub_tokens = []
523
+ while start < len(chars):
524
+ end = len(chars)
525
+ cur_substr = None
526
+ while start < end:
527
+ substr = "".join(chars[start:end])
528
+ if start > 0:
529
+ substr = "##" + substr
530
+ if substr in self.vocab:
531
+ cur_substr = substr
532
+ break
533
+ end -= 1
534
+ if cur_substr is None:
535
+ is_bad = True
536
+ break
537
+ sub_tokens.append(cur_substr)
538
+ start = end
539
+
540
+ if is_bad:
541
+ output_tokens.append(self.unk_token)
542
+ else:
543
+ output_tokens.extend(sub_tokens)
544
+ return output_tokens
545
+
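A minimal standalone sketch of the greedy longest-match-first WordPiece loop above, using a toy vocabulary to reproduce the "unaffable" example from the docstring:

vocab = {"un", "##aff", "##able", "aff", "able"}

def wordpiece(token, vocab, unk="[UNK]"):
    chars = list(token)
    sub_tokens, start = [], 0
    while start < len(chars):
        end = len(chars)
        cur = None
        while start < end:
            substr = "".join(chars[start:end])
            if start > 0:
                substr = "##" + substr  # continuation pieces carry the "##" prefix
            if substr in vocab:
                cur = substr
                break
            end -= 1
        if cur is None:
            return [unk]  # no piece matched: emit the unknown token
        sub_tokens.append(cur)
        start = end
    return sub_tokens

print(wordpiece("unaffable", vocab))  # ['un', '##aff', '##able']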
LAVT-RIS/bert/tokenization_utils.py ADDED
@@ -0,0 +1,723 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Tokenization classes for python tokenizers.
16
+ For fast tokenizers (provided by HuggingFace's tokenizers library) see tokenization_utils_fast.py
17
+ """
18
+
19
+ import itertools
20
+ import logging
21
+ import re
22
+ import unicodedata
23
+ from typing import Dict, List, Optional, Tuple, Union
24
+
25
+ from .file_utils import add_end_docstrings
26
+ from .tokenization_utils_base import (
27
+ ENCODE_KWARGS_DOCSTRING,
28
+ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
29
+ AddedToken,
30
+ BatchEncoding,
31
+ EncodedInput,
32
+ EncodedInputPair,
33
+ PaddingStrategy,
34
+ PreTokenizedInput,
35
+ PreTokenizedInputPair,
36
+ PreTrainedTokenizerBase,
37
+ TensorType,
38
+ TextInput,
39
+ TextInputPair,
40
+ TruncationStrategy,
41
+ )
42
+
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+ def _is_whitespace(char):
48
+ """Checks whether `chars` is a whitespace character."""
49
+ # \t, \n, and \r are technically control characters but we treat them
50
+ # as whitespace since they are generally considered as such.
51
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
52
+ return True
53
+ cat = unicodedata.category(char)
54
+ if cat == "Zs":
55
+ return True
56
+ return False
57
+
58
+
59
+ def _is_control(char):
60
+ """Checks whether `chars` is a control character."""
61
+ # These are technically control characters but we count them as whitespace
62
+ # characters.
63
+ if char == "\t" or char == "\n" or char == "\r":
64
+ return False
65
+ cat = unicodedata.category(char)
66
+ if cat.startswith("C"):
67
+ return True
68
+ return False
69
+
70
+
71
+ def _is_punctuation(char):
72
+ """Checks whether `chars` is a punctuation character."""
73
+ cp = ord(char)
74
+ # We treat all non-letter/number ASCII as punctuation.
75
+ # Characters such as "^", "$", and "`" are not in the Unicode
76
+ # Punctuation class but we treat them as punctuation anyways, for
77
+ # consistency.
78
+ if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
79
+ return True
80
+ cat = unicodedata.category(char)
81
+ if cat.startswith("P"):
82
+ return True
83
+ return False
84
+
85
+
86
+ def _is_end_of_word(text):
87
+ """Checks whether the last character in text is one of a punctuation, control or whitespace character."""
88
+ last_char = text[-1]
89
+ return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
90
+
91
+
92
+ def _is_start_of_word(text):
93
+ """Checks whether the first character in text is one of a punctuation, control or whitespace character."""
94
+ first_char = text[0]
95
+ return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
96
+
97
+
98
+ class PreTrainedTokenizer(PreTrainedTokenizerBase):
99
+ """ Base class for all slow tokenizers.
100
+
101
+ Handle all the shared methods for tokenization and special tokens as well as methods
102
+ downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary.
103
+
104
+ This class also contains the added tokens in a unified way on top of all tokenizers so we don't
105
+ have to handle the specific vocabulary augmentation methods of the various underlying
106
+ dictionary structures (BPE, sentencepiece...).
107
+
108
+ Class attributes (overridden by derived classes):
109
+
110
+ - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file
111
+ required by the model, and as associated values, the filename for saving the associated file (string).
112
+ - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys
113
+ being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the
114
+ `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the
115
+ associated pretrained vocabulary file.
116
+ - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained
117
+ models, and as associated values, the maximum length of the sequence inputs of this model, or None if the
118
+ model has no maximum input size.
119
+ - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the
120
+ pretrained models, and as associated values, a dictionary of specific arguments to pass to the
121
+ ``__init__`` method of the tokenizer class for this pretrained model when loading the tokenizer with the
122
+ ``from_pretrained()`` method.
123
+
124
+ Args:
125
+ - ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model.
126
+ When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated
127
+ model in ``max_model_input_sizes`` (see above). If no value is provided, will default to VERY_LARGE_INTEGER (`int(1e30)`) if
128
+ no associated max_length can be found in ``max_model_input_sizes``.
129
+ - ``padding_side``: (`Optional`) string: the side on which the model should have padding applied.
130
+ Should be selected between ['right', 'left']
131
+ - ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the
132
+ model ("token_type_ids", "attention_mask"...).
133
+ - ``bos_token``: (`Optional`) string: a beginning of sentence token.
134
+ Will be associated to ``self.bos_token`` and ``self.bos_token_id``
135
+ - ``eos_token``: (`Optional`) string: an end of sentence token.
136
+ Will be associated to ``self.eos_token`` and ``self.eos_token_id``
137
+ - ``unk_token``: (`Optional`) string: an unknown token.
138
+ Will be associated to ``self.unk_token`` and ``self.unk_token_id``
139
+ - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence).
140
+ Will be associated to ``self.sep_token`` and ``self.sep_token_id``
141
+ - ``pad_token``: (`Optional`) string: a padding token.
142
+ Will be associated to ``self.pad_token`` and ``self.pad_token_id``
143
+ - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence
144
+ leveraging self-attention along the full depth of the model).
145
+ Will be associated to ``self.cls_token`` and ``self.cls_token_id``
146
+ - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language
147
+ modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
148
+ - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens.
149
+ Adding all special tokens here ensures they won't be split by the tokenization process.
150
+ Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
151
+
152
+
153
+ .. automethod:: __call__
154
+ """
155
+
156
+ def __init__(self, **kwargs):
157
+ super().__init__(**kwargs)
158
+
159
+ # Added tokens - We store this for both slow and fast tokenizers
160
+ # until the serialization of Fast tokenizers is updated
161
+ self.added_tokens_encoder: Dict[str, int] = {}
162
+ self.added_tokens_decoder: Dict[int, str] = {}
163
+ self.unique_no_split_tokens: List[str] = []
164
+
165
+ @property
166
+ def is_fast(self) -> bool:
167
+ return False
168
+
169
+ @property
170
+ def vocab_size(self) -> int:
171
+ """ Size of the base vocabulary (without the added tokens) """
172
+ raise NotImplementedError
173
+
174
+ def get_vocab(self):
175
+ """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
176
+ raise NotImplementedError()
177
+
178
+ def get_added_vocab(self) -> Dict[str, int]:
179
+ return self.added_tokens_encoder
180
+
181
+ def __len__(self):
182
+ """ Size of the full vocabulary with the added tokens """
183
+ return self.vocab_size + len(self.added_tokens_encoder)
184
+
185
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens=False) -> int:
186
+ """
187
+ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
188
+ vocabulary, they are added to it with indices starting from length of the current vocabulary.
189
+
190
+ Args:
191
+ new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not
192
+ already in the vocabulary (tested by checking if the tokenizer assigns the index of the ``unk_token`` to them).
193
+
194
+ Returns:
195
+ Number of tokens added to the vocabulary.
196
+
197
+ Examples::
198
+
199
+ # Let's see how to increase the vocabulary of Bert model and tokenizer
200
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
201
+ model = BertModel.from_pretrained('bert-base-uncased')
202
+
203
+ num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
204
+ print('We have added', num_added_toks, 'tokens')
205
+ model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
206
+ """
207
+ new_tokens = [str(tok) for tok in new_tokens]
208
+
209
+ tokens_to_add = []
210
+ for token in new_tokens:
211
+ assert isinstance(token, str)
212
+ if not special_tokens and self.init_kwargs.get("do_lower_case", False):
213
+ token = token.lower()
214
+ if (
215
+ token != self.unk_token
216
+ and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
217
+ and token not in tokens_to_add
218
+ ):
219
+ tokens_to_add.append(token)
220
+ if self.verbose:
221
+ logger.info("Adding %s to the vocabulary", token)
222
+
223
+ added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
224
+ added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
225
+ self.added_tokens_encoder.update(added_tok_encoder)
226
+ self.added_tokens_decoder.update(added_tok_decoder)
227
+
228
+ # Make sure we don't split on any special tokens (even if they were already in the vocab before, e.g. for Albert)
229
+ if special_tokens:
230
+ self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens)))
231
+ else:
232
+ # Or on the newly added tokens
233
+ self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
234
+
235
+ return len(tokens_to_add)
236
+
237
+ def num_special_tokens_to_add(self, pair=False):
238
+ """
239
+ Returns the number of added tokens when encoding a sequence with special tokens.
240
+
241
+ Note:
242
+ This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
243
+ inside your training loop.
244
+
245
+ Args:
246
+ pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
247
+ number of added tokens in the case of a single sequence if set to False.
248
+
249
+ Returns:
250
+ Number of tokens added to sequences
251
+ """
252
+ token_ids_0 = []
253
+ token_ids_1 = []
254
+ return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
255
+
256
+ def tokenize(self, text: TextInput, **kwargs):
257
+ """ Converts a string in a sequence of tokens (string), using the tokenizer.
258
+ Split in words for word-based vocabulary or sub-words for sub-word-based
259
+ vocabularies (BPE/SentencePieces/WordPieces).
260
+
261
+ Take care of added tokens.
262
+
263
+ Args:
264
+ text (:obj:`string`): The sequence to be encoded.
265
+ **kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method.
266
+ """
267
+ # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
268
+ all_special_tokens_extended = dict(
269
+ (str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
270
+ )
271
+
272
+ text, kwargs = self.prepare_for_tokenization(text, **kwargs)
273
+
274
+ if kwargs:
275
+ logger.warning(f"Keyword arguments {kwargs} not recognized.")
276
+
277
+ # TODO: should this be in the base class?
278
+ if self.init_kwargs.get("do_lower_case", False):
279
+ # convert non-special tokens to lowercase
280
+ escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
281
+ pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
282
+ text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
283
+
284
+ def split_on_token(tok, text):
285
+ result = []
286
+ tok_extended = all_special_tokens_extended.get(tok, None)
287
+ split_text = text.split(tok)
288
+ full_word = ""
289
+ for i, sub_text in enumerate(split_text):
290
+ # AddedToken can control whitespace stripping around them.
291
+ # We use them for GPT2 and Roberta to have different behavior depending on the special token
292
+ # Cf. https://github.com/huggingface/transformers/pull/2778
293
+ # and https://github.com/huggingface/transformers/issues/3788
294
+ if isinstance(tok_extended, AddedToken):
295
+ if tok_extended.single_word:
296
+ # Try to avoid splitting on token
297
+ if (
298
+ i < len(split_text) - 1
299
+ and not _is_end_of_word(sub_text)
300
+ and not _is_start_of_word(split_text[i + 1])
301
+ ):
302
+ # Don't extract the special token
303
+ full_word += sub_text + tok
304
+ elif full_word:
305
+ full_word += sub_text
306
+ result += [full_word]
307
+ full_word = ""
308
+ continue
309
+ # Strip white spaces on the right
310
+ if tok_extended.rstrip and i > 0:
311
+ # A bit counter-intuitive but we strip the left of the string
312
+ # since tok_extended.rstrip means the special token is eating all white spaces on its right
313
+ sub_text = sub_text.lstrip()
314
+ # Strip white spaces on the left
315
+ if tok_extended.lstrip and i < len(split_text) - 1:
316
+ sub_text = sub_text.rstrip() # Opposite here
317
+ else:
318
+ # We strip left and right by default
319
+ if i < len(split_text) - 1:
320
+ sub_text = sub_text.rstrip()
321
+ if i > 0:
322
+ sub_text = sub_text.lstrip()
323
+
324
+ if i == 0 and not sub_text:
325
+ result += [tok]
326
+ elif i == len(split_text) - 1:
327
+ if sub_text:
328
+ result += [sub_text]
329
+ else:
330
+ pass
331
+ else:
332
+ if sub_text:
333
+ result += [sub_text]
334
+ result += [tok]
335
+ return result
336
+
337
+ def split_on_tokens(tok_list, text):
338
+ if not text.strip():
339
+ return []
340
+ if not tok_list:
341
+ return self._tokenize(text)
342
+
343
+ tokenized_text = []
344
+ text_list = [text]
345
+ for tok in tok_list:
346
+ tokenized_text = []
347
+ for sub_text in text_list:
348
+ if sub_text not in self.unique_no_split_tokens:
349
+ tokenized_text += split_on_token(tok, sub_text)
350
+ else:
351
+ tokenized_text += [sub_text]
352
+ text_list = tokenized_text
353
+
354
+ return list(
355
+ itertools.chain.from_iterable(
356
+ (
357
+ self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
358
+ for token in tokenized_text
359
+ )
360
+ )
361
+ )
362
+
363
+ no_split_token = self.unique_no_split_tokens
364
+ tokenized_text = split_on_tokens(no_split_token, text)
365
+ return tokenized_text
366
+
367
+ def _tokenize(self, text, **kwargs):
368
+ """ Converts a string in a sequence of tokens (string), using the tokenizer.
369
+ Split in words for word-based vocabulary or sub-words for sub-word-based
370
+ vocabularies (BPE/SentencePieces/WordPieces).
371
+
372
+ Do NOT take care of added tokens.
373
+ """
374
+ raise NotImplementedError
375
+
376
+ def convert_tokens_to_ids(self, tokens):
377
+ """ Converts a token string (or a sequence of tokens) in a single integer id
378
+ (or a sequence of ids), using the vocabulary.
379
+ """
380
+ if tokens is None:
381
+ return None
382
+
383
+ if isinstance(tokens, str):
384
+ return self._convert_token_to_id_with_added_voc(tokens)
385
+
386
+ ids = []
387
+ for token in tokens:
388
+ ids.append(self._convert_token_to_id_with_added_voc(token))
389
+ return ids
390
+
391
+ def _convert_token_to_id_with_added_voc(self, token):
392
+ if token is None:
393
+ return None
394
+
395
+ if token in self.added_tokens_encoder:
396
+ return self.added_tokens_encoder[token]
397
+ return self._convert_token_to_id(token)
398
+
399
+ def _convert_token_to_id(self, token):
400
+ raise NotImplementedError
401
+
402
+ def _encode_plus(
403
+ self,
404
+ text: Union[TextInput, PreTokenizedInput, EncodedInput],
405
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
406
+ add_special_tokens: bool = True,
407
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
408
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
409
+ max_length: Optional[int] = None,
410
+ stride: int = 0,
411
+ is_pretokenized: bool = False,
412
+ pad_to_multiple_of: Optional[int] = None,
413
+ return_tensors: Optional[Union[str, TensorType]] = None,
414
+ return_token_type_ids: Optional[bool] = None,
415
+ return_attention_mask: Optional[bool] = None,
416
+ return_overflowing_tokens: bool = False,
417
+ return_special_tokens_mask: bool = False,
418
+ return_offsets_mapping: bool = False,
419
+ return_length: bool = False,
420
+ verbose: bool = True,
421
+ **kwargs
422
+ ) -> BatchEncoding:
423
+ def get_input_ids(text):
424
+ if isinstance(text, str):
425
+ tokens = self.tokenize(text, **kwargs)
426
+ return self.convert_tokens_to_ids(tokens)
427
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
428
+ if is_pretokenized:
429
+ tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
430
+ return self.convert_tokens_to_ids(tokens)
431
+ else:
432
+ return self.convert_tokens_to_ids(text)
433
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
434
+ return text
435
+ else:
436
+ if is_pretokenized:
437
+ raise ValueError(
438
+ f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_pretokenized=True`."
439
+ )
440
+ else:
441
+ raise ValueError(
442
+ f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
443
+ )
444
+
445
+ if return_offsets_mapping:
446
+ raise NotImplementedError(
447
+ "return_offset_mapping is not available when using Python tokenizers."
448
+ "To use this feature, change your tokenizer to one deriving from "
449
+ "transformers.PreTrainedTokenizerFast."
450
+ "More information on available tokenizers at "
451
+ "https://github.com/huggingface/transformers/pull/2674"
452
+ )
453
+
454
+ first_ids = get_input_ids(text)
455
+ second_ids = get_input_ids(text_pair) if text_pair is not None else None
456
+
457
+ return self.prepare_for_model(
458
+ first_ids,
459
+ pair_ids=second_ids,
460
+ add_special_tokens=add_special_tokens,
461
+ padding=padding_strategy.value,
462
+ truncation=truncation_strategy.value,
463
+ max_length=max_length,
464
+ stride=stride,
465
+ pad_to_multiple_of=pad_to_multiple_of,
466
+ return_tensors=return_tensors,
467
+ prepend_batch_axis=True,
468
+ return_attention_mask=return_attention_mask,
469
+ return_token_type_ids=return_token_type_ids,
470
+ return_overflowing_tokens=return_overflowing_tokens,
471
+ return_special_tokens_mask=return_special_tokens_mask,
472
+ return_length=return_length,
473
+ verbose=verbose,
474
+ )
475
+
476
+ def _batch_encode_plus(
477
+ self,
478
+ batch_text_or_text_pairs: Union[
479
+ List[TextInput],
480
+ List[TextInputPair],
481
+ List[PreTokenizedInput],
482
+ List[PreTokenizedInputPair],
483
+ List[EncodedInput],
484
+ List[EncodedInputPair],
485
+ ],
486
+ add_special_tokens: bool = True,
487
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
488
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
489
+ max_length: Optional[int] = None,
490
+ stride: int = 0,
491
+ is_pretokenized: bool = False,
492
+ pad_to_multiple_of: Optional[int] = None,
493
+ return_tensors: Optional[Union[str, TensorType]] = None,
494
+ return_token_type_ids: Optional[bool] = None,
495
+ return_attention_mask: Optional[bool] = None,
496
+ return_overflowing_tokens: bool = False,
497
+ return_special_tokens_mask: bool = False,
498
+ return_offsets_mapping: bool = False,
499
+ return_length: bool = False,
500
+ verbose: bool = True,
501
+ **kwargs
502
+ ) -> BatchEncoding:
503
+ def get_input_ids(text):
504
+ if isinstance(text, str):
505
+ tokens = self.tokenize(text, **kwargs)
506
+ return self.convert_tokens_to_ids(tokens)
507
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
508
+ if is_pretokenized:
509
+ tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
510
+ return self.convert_tokens_to_ids(tokens)
511
+ else:
512
+ return self.convert_tokens_to_ids(text)
513
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
514
+ return text
515
+ else:
516
+ raise ValueError(
517
+ "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
518
+ )
519
+
520
+ if return_offsets_mapping:
521
+ raise NotImplementedError(
522
+ "return_offset_mapping is not available when using Python tokenizers."
523
+ "To use this feature, change your tokenizer to one deriving from "
524
+ "transformers.PreTrainedTokenizerFast."
525
+ )
526
+
527
+ input_ids = []
528
+ for ids_or_pair_ids in batch_text_or_text_pairs:
529
+ if not isinstance(ids_or_pair_ids, (list, tuple)):
530
+ ids, pair_ids = ids_or_pair_ids, None
531
+ elif is_pretokenized and not isinstance(ids_or_pair_ids[0], (list, tuple)):
532
+ ids, pair_ids = ids_or_pair_ids, None
533
+ else:
534
+ ids, pair_ids = ids_or_pair_ids
535
+
536
+ first_ids = get_input_ids(ids)
537
+ second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
538
+ input_ids.append((first_ids, second_ids))
539
+
540
+ batch_outputs = self._batch_prepare_for_model(
541
+ input_ids,
542
+ add_special_tokens=add_special_tokens,
543
+ padding_strategy=padding_strategy,
544
+ truncation_strategy=truncation_strategy,
545
+ max_length=max_length,
546
+ stride=stride,
547
+ pad_to_multiple_of=pad_to_multiple_of,
548
+ return_attention_mask=return_attention_mask,
549
+ return_token_type_ids=return_token_type_ids,
550
+ return_overflowing_tokens=return_overflowing_tokens,
551
+ return_special_tokens_mask=return_special_tokens_mask,
552
+ return_length=return_length,
553
+ return_tensors=return_tensors,
554
+ verbose=verbose,
555
+ )
556
+
557
+ return BatchEncoding(batch_outputs)
558
+
559
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
560
+ def _batch_prepare_for_model(
561
+ self,
562
+ batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
563
+ add_special_tokens: bool = True,
564
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
565
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
566
+ max_length: Optional[int] = None,
567
+ stride: int = 0,
568
+ pad_to_multiple_of: Optional[int] = None,
569
+ return_tensors: Optional[str] = None,
570
+ return_token_type_ids: Optional[bool] = None,
571
+ return_attention_mask: Optional[bool] = None,
572
+ return_overflowing_tokens: bool = False,
573
+ return_special_tokens_mask: bool = False,
574
+ return_length: bool = False,
575
+ verbose: bool = True,
576
+ ) -> BatchEncoding:
577
+ """ Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
578
+ It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
579
+ manages a moving window (with user defined stride) for overflowing tokens
580
+
581
+ Args:
582
+ batch_ids_pairs: list of tokenized input ids or input ids pairs
583
+ """
584
+
585
+ batch_outputs = {}
586
+ for first_ids, second_ids in batch_ids_pairs:
587
+ outputs = self.prepare_for_model(
588
+ first_ids,
589
+ second_ids,
590
+ add_special_tokens=add_special_tokens,
591
+ padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
592
+ truncation=truncation_strategy.value,
593
+ max_length=max_length,
594
+ stride=stride,
595
+ pad_to_multiple_of=None, # we pad in batch afterward
596
+ return_attention_mask=False, # we pad in batch afterward
597
+ return_token_type_ids=return_token_type_ids,
598
+ return_overflowing_tokens=return_overflowing_tokens,
599
+ return_special_tokens_mask=return_special_tokens_mask,
600
+ return_length=return_length,
601
+ return_tensors=None, # We convert the whole batch to tensors at the end
602
+ prepend_batch_axis=False,
603
+ verbose=verbose,
604
+ )
605
+
606
+ for key, value in outputs.items():
607
+ if key not in batch_outputs:
608
+ batch_outputs[key] = []
609
+ batch_outputs[key].append(value)
610
+
611
+ batch_outputs = self.pad(
612
+ batch_outputs,
613
+ padding=padding_strategy.value,
614
+ max_length=max_length,
615
+ pad_to_multiple_of=pad_to_multiple_of,
616
+ return_attention_mask=return_attention_mask,
617
+ )
618
+
619
+ batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
620
+
621
+ return batch_outputs
622
+
623
+ def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict):
624
+ """ Performs any necessary transformations before tokenization.
625
+
626
+ This method should pop the arguments from kwargs and return kwargs as well.
627
+ We test kwargs at the end of the encoding process to be sure all the arguments have been used.
628
+ """
629
+ return (text, kwargs)
630
+
631
+ def get_special_tokens_mask(
632
+ self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
633
+ ) -> List[int]:
634
+ """
635
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
636
+ special tokens using the tokenizer ``prepare_for_model`` method.
637
+
638
+ Args:
639
+ token_ids_0: list of ids (must not contain special tokens)
640
+ token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
641
+ for sequence pairs
642
+ already_has_special_tokens: (default False) Set to True if the token list is already formatted with
643
+ special tokens for the model
644
+
645
+ Returns:
646
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
647
+ """
648
+ return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
649
+
650
+ def convert_ids_to_tokens(
651
+ self, ids: Union[int, List[int]], skip_special_tokens: bool = False
652
+ ) -> Union[str, List[str]]:
653
+ """ Converts a single index or a sequence of indices (integers) in a token "
654
+ (resp.) a sequence of tokens (str), using the vocabulary and added tokens.
655
+
656
+ Args:
657
+ skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
658
+ """
659
+ if isinstance(ids, int):
660
+ if ids in self.added_tokens_decoder:
661
+ return self.added_tokens_decoder[ids]
662
+ else:
663
+ return self._convert_id_to_token(ids)
664
+ tokens = []
665
+ for index in ids:
666
+ index = int(index)
667
+ if skip_special_tokens and index in self.all_special_ids:
668
+ continue
669
+ if index in self.added_tokens_decoder:
670
+ tokens.append(self.added_tokens_decoder[index])
671
+ else:
672
+ tokens.append(self._convert_id_to_token(index))
673
+ return tokens
674
+
675
+ def _convert_id_to_token(self, index: int) -> str:
676
+ raise NotImplementedError
677
+
678
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
679
+ """ Converts a sequence of tokens (string) in a single string.
680
+ The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
681
+ but we often want to remove sub-word tokenization artifacts at the same time.
682
+ """
683
+ return " ".join(self.convert_ids_to_tokens(tokens))
684
+
685
+ def decode(
686
+ self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
687
+ ) -> str:
688
+ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
689
+
690
+ # To avoid mixing byte-level and unicode for byte-level BPE
691
+ # we need to build the string separately for added tokens and byte-level tokens
692
+ # cf. https://github.com/huggingface/transformers/issues/1133
693
+ sub_texts = []
694
+ current_sub_text = []
695
+ for token in filtered_tokens:
696
+ if skip_special_tokens and token in self.all_special_ids:
697
+ continue
698
+ if token in self.added_tokens_encoder:
699
+ if current_sub_text:
700
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
701
+ current_sub_text = []
702
+ sub_texts.append(token)
703
+ else:
704
+ current_sub_text.append(token)
705
+ if current_sub_text:
706
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
707
+ text = " ".join(sub_texts)
708
+
709
+ if clean_up_tokenization_spaces:
710
+ clean_text = self.clean_up_tokenization(text)
711
+ return clean_text
712
+ else:
713
+ return text
714
+
715
+ def save_vocabulary(self, save_directory) -> Tuple[str]:
716
+ """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
717
+ and special token mappings.
718
+
719
+ Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full
720
+ Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained`
721
+ class method.
722
+ """
723
+ raise NotImplementedError
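Note on the slow-tokenizer path above: add_tokens() routes through _add_tokens(), which appends unseen tokens after the base vocabulary and registers them in unique_no_split_tokens, so tokenize() protects them from WordPiece splitting via split_on_tokens(). A minimal sketch of that behaviour (not part of the uploaded files), assuming the vendored bert.tokenization_bert.BertTokenizer mirrors the upstream transformers API and that the 'bert-base-uncased' vocabulary is cached locally or downloadable:

from bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 'verbphrase' is a hypothetical domain token that is not in the BERT vocabulary.
num_added = tokenizer.add_tokens(['verbphrase'])   # -> 1; stored in added_tokens_encoder

tokens = tokenizer.tokenize('the person verbphrase near the car')
# Because 'verbphrase' is now in unique_no_split_tokens, split_on_tokens() emits it as a
# single token instead of letting _tokenize() break it into WordPiece sub-words.

ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokenizer.convert_ids_to_tokens(ids))        # the added id decodes back via added_tokens_decoder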
LAVT-RIS/bert/tokenization_utils_base.py ADDED
The diff for this file is too large to render. See raw diff
 
LAVT-RIS/data.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
LAVT-RIS/data/__pycache__/dataset_refer_bert.cpython-39.pyc ADDED
Binary file (6.17 kB). View file
 
LAVT-RIS/data/__pycache__/dataset_refer_bert_mostat.cpython-39.pyc ADDED
Binary file (3.9 kB). View file
 
LAVT-RIS/data/__pycache__/dataset_refer_bert_rev.cpython-39.pyc ADDED
Binary file (6.87 kB). View file
 
LAVT-RIS/data/__pycache__/dataset_refer_zom.cpython-39.pyc ADDED
Binary file (8.63 kB). View file
 
LAVT-RIS/data/dataset_refer_bert.py ADDED
@@ -0,0 +1,228 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import torch.utils.data as data
5
+ import torch
6
+ from torchvision import transforms
7
+ import numpy as np
8
+ from PIL import Image
9
+ import torchvision.transforms.functional as TF
10
+ import random
11
+
12
+ from bert.tokenization_bert import BertTokenizer
13
+ from refer.refer import REFER
14
+
15
+ from args import get_parser
16
+
17
+ # Dataset configuration initialization
18
+ # parser = get_parser()
19
+ # args = parser.parse_args()
20
+
21
+
22
+ class ReferDataset(data.Dataset):
23
+
24
+ def __init__(self,
25
+ args,
26
+ image_transforms=None,
27
+ target_transforms=None,
28
+ split='train',
29
+ eval_mode=False):
30
+
31
+ self.classes = []
32
+ self.image_transforms = image_transforms
33
+ self.target_transform = target_transforms
34
+ self.split = split
35
+ self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy)
36
+
37
+ self.max_tokens = 20
38
+
39
+ ref_ids = self.refer.getRefIds(split=self.split)
40
+ img_ids = self.refer.getImgIds(ref_ids)
41
+
42
+ all_imgs = self.refer.Imgs
43
+ self.imgs = list(all_imgs[i] for i in img_ids)
44
+ self.ref_ids = ref_ids
45
+
46
+ self.input_ids = []
47
+ self.attention_masks = []
48
+ self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)
49
+
50
+ # for metric learning
51
+ self.ROOT = '/data2/dataset/RefCOCO/VRIS'
52
+ self.metric_learning = args.metric_learning
53
+ self.exclude_multiobj = args.exclude_multiobj
54
+ self.metric_mode = args.metric_mode
55
+ self.exclude_position = False
56
+
57
+ if self.metric_learning and eval_mode == False:
58
+ self.hardneg_prob = args.hn_prob
59
+ self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
60
+ self.hardpos_meta, self.hardneg_meta = self._load_metadata()
61
+ else:
62
+ self.hardneg_prob = 0.0
63
+ self.multi_obj_ref_ids = None
64
+ self.hardpos_meta, self.hardneg_meta = None, None
65
+
66
+
67
+ self.eval_mode = eval_mode
68
+ # if we are testing on a dataset, test all sentences of an object;
69
+ # o/w, we are validating during training, randomly sample one sentence for efficiency
70
+ for r in ref_ids:
71
+ ref = self.refer.Refs[r]
72
+
73
+ sentences_for_ref = []
74
+ attentions_for_ref = []
75
+
76
+ for i, (el, sent_id) in enumerate(zip(ref['sentences'], ref['sent_ids'])):
77
+ sentence_raw = el['raw']
78
+ attention_mask = [0] * self.max_tokens
79
+ padded_input_ids = [0] * self.max_tokens
80
+
81
+ input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True)
82
+
83
+ # truncation of tokens
84
+ input_ids = input_ids[:self.max_tokens]
85
+
86
+ padded_input_ids[:len(input_ids)] = input_ids
87
+ attention_mask[:len(input_ids)] = [1]*len(input_ids)
88
+
89
+ sentences_for_ref.append(torch.tensor(padded_input_ids).unsqueeze(0))
90
+ attentions_for_ref.append(torch.tensor(attention_mask).unsqueeze(0))
91
+
92
+ self.input_ids.append(sentences_for_ref)
93
+ self.attention_masks.append(attentions_for_ref)
94
+
95
+
96
+ def _tokenize(self, sentence):
97
+ attention_mask = [0] * self.max_tokens
98
+ padded_input_ids = [0] * self.max_tokens
99
+
100
+ input_ids = self.tokenizer.encode(text=sentence, add_special_tokens=True)
101
+ # truncation of tokens
102
+ input_ids = input_ids[:self.max_tokens]
103
+ padded_input_ids[:len(input_ids)] = input_ids
104
+ attention_mask[:len(input_ids)] = [1]*len(input_ids)
105
+
106
+ # match shape as (1, max_tokens)
107
+ return torch.tensor(padded_input_ids).unsqueeze(0), torch.tensor(attention_mask).unsqueeze(0)
108
+
109
+
110
+ def _load_multi_obj_ref_ids(self):
111
+ # Load multi-object reference IDs based on configurations
112
+ if not self.exclude_multiobj and not self.exclude_position :
113
+ return None
114
+ elif self.exclude_position:
115
+ multiobj_path = os.path.join(self.ROOT, 'multiobj_ov2_nopos.txt')
116
+ elif self.exclude_multiobj :
117
+ multiobj_path = os.path.join(self.ROOT, 'multiobj_ov3.txt')
118
+ with open(multiobj_path, 'r') as f:
119
+ return [int(line.strip()) for line in f.readlines()]
120
+
121
+ def _load_metadata(self):
122
+ # Load metadata for hard positive verb phrases, hard negative queries
123
+ if 'refined' in self.metric_mode or 'hardneg' in self.metric_mode :
124
+ hardpos_path = os.path.join(self.ROOT, 'hardpos_verdict_gref_v4.json')
125
+ else :
126
+ hardpos_path = os.path.join(self.ROOT, 'hardpos_verbphrase_0906upd.json')
127
+ # do not use hardneg_path
128
+ hardneg_path = os.path.join(self.ROOT, 'hardneg_verb.json')
129
+
130
+ with open(hardpos_path, 'r', encoding='utf-8') as f:
131
+ hardpos_json = json.load(f)
132
+ if "hardpos_only" in self.metric_mode :
133
+ hardneg_json = None
134
+ else :
135
+ with open(hardneg_path, 'r', encoding='utf-8') as q:
136
+ hardneg_json = json.load(q)
137
+ return hardpos_json, hardneg_json
138
+
139
+
140
+ def get_classes(self):
141
+ return self.classes
142
+
143
+ def __len__(self):
144
+ return len(self.ref_ids)
145
+
146
+ def __getitem__(self, index):
147
+ this_ref_id = self.ref_ids[index]
148
+ this_img_id = self.refer.getImgIds(this_ref_id)
149
+ this_img = self.refer.Imgs[this_img_id[0]]
150
+
151
+ IMAGE_DIR = '/data2/dataset/COCO2014/trainval2014/'
152
+ img = Image.open(os.path.join(IMAGE_DIR, this_img['file_name'])).convert("RGB")
153
+
154
+ ref = self.refer.loadRefs(this_ref_id)
155
+ #print(ref)
156
+
157
+ ref_mask = np.array(self.refer.getMask(ref[0])['mask'])
158
+ annot = np.zeros(ref_mask.shape)
159
+ annot[ref_mask == 1] = 1
160
+ annot = Image.fromarray(annot.astype(np.uint8), mode="P")
161
+
162
+ if self.image_transforms is not None:
163
+ # resize, from PIL to tensor, and mean and std normalization
164
+ img, target = self.image_transforms(img, annot)
165
+
166
+ if self.eval_mode:
167
+ embedding = []
168
+ att = []
169
+ for s in range(len(self.input_ids[index])):
170
+ e = self.input_ids[index][s]
171
+ a = self.attention_masks[index][s]
172
+ embedding.append(e.unsqueeze(-1))
173
+ att.append(a.unsqueeze(-1))
174
+
175
+ tensor_embeddings = torch.cat(embedding, dim=-1)
176
+ attention_mask = torch.cat(att, dim=-1)
177
+
178
+ return img, target, tensor_embeddings, attention_mask
179
+
180
+ else: # train phase
181
+ choice_sent = np.random.choice(len(self.input_ids[index]))
182
+ tensor_embeddings = self.input_ids[index][choice_sent]
183
+ attention_mask = self.attention_masks[index][choice_sent]
184
+
185
+ pos_sent = torch.zeros_like(tensor_embeddings)
186
+ neg_sent = torch.zeros_like(tensor_embeddings)
187
+ pos_attn_mask = torch.zeros_like(attention_mask)
188
+ neg_attn_mask = torch.zeros_like(attention_mask)
189
+
190
+ if self.metric_learning:
191
+ if 'hardpos_' in self.metric_mode or self.hardneg_prob == 0.0:
192
+ pos_sents = self.hardpos_meta[str(this_ref_id)].values()
193
+ # drop elements with none
194
+ pos_sents = [s for s in pos_sents if s is not None]
195
+ pos_sent_picked = random.choice(list(pos_sents))
196
+ if pos_sent_picked:
197
+ pos_sent, pos_attn_mask = self._tokenize(pos_sent_picked)
198
+ else:
199
+ pos_sents = self.hardpos_meta[str(this_ref_id)].values()
200
+ # drop elements with none
201
+ pos_sents = [s for s in pos_sents if s is not None]
202
+ pos_sent_picked = random.choice(list(pos_sents))
203
+
204
+ if pos_sent_picked:
205
+ pos_sent, pos_attn_mask = self._tokenize(pos_sent_picked)
206
+
207
+ if random.random() < self.hardneg_prob:
208
+ neg_sents = self.hardneg_meta[str(this_ref_id)].values()
209
+ neg_sents = [s for s in neg_sents if s is not None]
210
+ neg_sent_picked = random.choice(list(neg_sents))
211
+ #print("neg_sent: ", neg_sent)
212
+
213
+ if neg_sent_picked:
214
+ neg_sent, neg_attn_mask = self._tokenize(neg_sent_picked)
215
+
216
+ # print("index: ", self.input_ids[index])
217
+ # print("choice_sent: ", choice_sent)
218
+ # print("tensor_embeddings: ", tensor_embeddings)
219
+ # print("original sentence: ", self.tokenizer.decode(tensor_embeddings.squeeze(0).tolist()))
220
+ # print("pos_sent: ", pos_sent)
221
+ # print("neg_sent: ", neg_sent)
222
+ # print("pos_attn_mask: ", pos_attn_mask)
223
+ # print("neg_attn_mask: ", neg_attn_mask)
224
+
225
+ #exit()
226
+
227
+
228
+ return img, target, tensor_embeddings, attention_mask, pos_sent, pos_attn_mask, neg_sent, neg_attn_mask
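For context on how the train-phase 8-tuple above is consumed: a minimal sketch (not part of the uploaded files) wiring ReferDataset into a DataLoader. The args namespace (refer_data_root, dataset, splitBy, bert_tokenizer, metric_learning, metric_mode, hn_prob, exclude_multiobj) and train_transforms are placeholders supplied by the surrounding training script, not defined here.

from torch.utils.data import DataLoader
from data.dataset_refer_bert import ReferDataset

train_ds = ReferDataset(args, image_transforms=train_transforms, split='train', eval_mode=False)
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=4, pin_memory=True)

for batch in train_loader:
    # pos_* / neg_* tensors are all-zero placeholders whenever no hard positive or
    # hard negative sentence was sampled for that reference.
    img, target, sent_ids, sent_mask, pos_ids, pos_mask, neg_ids, neg_mask = batch
    break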
LAVT-RIS/data/dataset_refer_bert_mostat.py ADDED
@@ -0,0 +1,136 @@
1
+ import os
2
+ import sys
3
+ import torch.utils.data as data
4
+ import torch
5
+ from torchvision import transforms
6
+ from torch.autograd import Variable
7
+ import numpy as np
8
+ from PIL import Image
9
+ import torchvision.transforms.functional as TF
10
+ import random
11
+
12
+ from bert.tokenization_bert import BertTokenizer
13
+
14
+ import h5py
15
+ from refer.refer import REFER
16
+
17
+ from args import get_parser
18
+
19
+ # Dataset configuration initialization
20
+ parser = get_parser()
21
+ args = parser.parse_args()
22
+
23
+
24
+ class ReferDataset(data.Dataset):
25
+
26
+ def __init__(self,
27
+ args,
28
+ image_transforms=None,
29
+ target_transforms=None,
30
+ split='train',
31
+ eval_mode=False):
32
+
33
+ self.classes = []
34
+ self.image_transforms = image_transforms
35
+ self.target_transform = target_transforms
36
+ self.split = split
37
+ self.dataset = args.dataset
38
+ self.args = args
39
+ if args.dataset == 'refcocog' and args.split in ['motion', 'static']:
40
+ import json
41
+ print(f"Easy & Hard Example Experiments - dataset : {args.dataset}, split : {args.split}")
42
+ if args.split == 'motion' :
43
+ meta_fp = '/data2/projects/chaeyun/LAVT-RIS/test_ablation_motion.json'
44
+ else :
45
+ meta_fp = '/data2/projects/chaeyun/LAVT-RIS/test_ablation_static.json'
46
+
47
+ with open(meta_fp, 'r') as f :
48
+ ref_metas = json.load(f)
49
+
50
+ self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy)
51
+
52
+ self.max_tokens = 20
53
+
54
+ # motion, static split binning
55
+ self.input_ids = []
56
+ self.attention_masks = []
57
+ self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)
58
+ self.ref_ids = []
59
+ self.eval_mode = eval_mode
60
+ self.refer_ctmz = {}
61
+
62
+ for ref in ref_metas :
63
+ sentences_for_ref = []
64
+ attentions_for_ref = []
65
+ sent_lens_for_ref = []
66
+
67
+ for i, sents in enumerate(ref['sentences']) :
68
+ sentence_raw = sents['sent']
69
+
70
+ attention_mask = [0] * self.max_tokens
71
+ padded_input_ids = [0] * self.max_tokens
72
+ input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True)
73
+ input_ids = input_ids[:self.max_tokens]
74
+
75
+ padded_input_ids[:len(input_ids)] = input_ids
76
+ attention_mask[:len(input_ids)] = [1]*len(input_ids)
77
+
78
+ sentences_for_ref.append(torch.tensor(padded_input_ids).unsqueeze(0))
79
+ attentions_for_ref.append(torch.tensor(attention_mask).unsqueeze(0))
80
+
81
+ self.input_ids.append(sentences_for_ref)
82
+ self.attention_masks.append(attentions_for_ref)
83
+ self.ref_ids.append(ref['segment_id'])
84
+
85
+ if ref['segment_id'] not in self.refer_ctmz :
86
+ self.refer_ctmz[ref['segment_id']] = ref
87
+
88
+ img_ids = self.refer.getImgIds(self.ref_ids)
89
+ all_imgs = self.refer.Imgs
90
+ self.imgs = list(all_imgs[i] for i in img_ids)
91
+
92
+ def get_classes(self):
93
+ return self.classes
94
+
95
+ def __len__(self):
96
+ return len(self.ref_ids)
97
+
98
+ def __getitem__(self, index):
99
+ this_ref_id = self.ref_ids[index]
100
+ this_img_id = self.refer.getImgIds(this_ref_id)
101
+ this_img = self.refer.Imgs[this_img_id[0]]
102
+
103
+ IMAGE_DIR = '/data2/dataset/COCO2014/train2014/'
104
+ img = Image.open(os.path.join(IMAGE_DIR, this_img['file_name'])).convert("RGB")
105
+ ref_orig = self.refer.loadRefs(this_ref_id)
106
+ ref = self.refer_ctmz[this_ref_id]
107
+
108
+ ref_mask = np.array(self.refer.getMask(ref_orig[0])['mask'])
109
+ annot = np.zeros(ref_mask.shape)
110
+
111
+ annot[ref_mask == 1] = 1
112
+ annot = Image.fromarray(annot.astype(np.uint8), mode="P")
113
+
114
+ if self.image_transforms is not None:
115
+ # resize, from PIL to tensor, and mean and std normalization
116
+ img, target = self.image_transforms(img, annot)
117
+
118
+ if self.eval_mode:
119
+ embedding = []
120
+ att = []
121
+ for s in range(len(self.input_ids[index])):
122
+ e = self.input_ids[index][s]
123
+ a = self.attention_masks[index][s]
124
+ embedding.append(e.unsqueeze(-1))
125
+ att.append(a.unsqueeze(-1))
126
+
127
+ tensor_embeddings = torch.cat(embedding, dim=-1)
128
+ attention_mask = torch.cat(att, dim=-1)
129
+
130
+ return img, target, tensor_embeddings, attention_mask
131
+ else:
132
+ choice_sent = np.random.choice(len(self.input_ids[index]))
133
+ tensor_embeddings = self.input_ids[index][choice_sent]
134
+ attention_mask = self.attention_masks[index][choice_sent]
135
+
136
+ return img, target, tensor_embeddings, attention_mask
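In eval_mode the datasets above stack every sentence of a reference along the last axis, so the text tensors gain a trailing sentence dimension. A minimal sketch (assumptions: eval_ds is one of these datasets built with eval_mode=True; the model(...) call is a placeholder for the LAVT forward pass used elsewhere in the repo):

img, target, sent_ids, sent_mask = eval_ds[0]
# sent_ids / sent_mask have shape (1, max_tokens, num_sents)

for j in range(sent_ids.shape[-1]):
    ids_j = sent_ids[..., j]      # (1, max_tokens) for the j-th expression
    mask_j = sent_mask[..., j]
    # pred = model(img.unsqueeze(0), ids_j, mask_j)   # hypothetical forward call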
LAVT-RIS/data/dataset_refer_bert_rev.py ADDED
@@ -0,0 +1,246 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import torch.utils.data as data
5
+ import torch
6
+ import itertools
7
+ from torchvision import transforms
8
+ from torch.autograd import Variable
9
+ import numpy as np
10
+ from PIL import Image
11
+ import torchvision.transforms.functional as TF
12
+ import random
13
+
14
+ from bert.tokenization_bert import BertTokenizer
15
+
16
+ import h5py
17
+ from refer.refer import REFER
18
+
19
+ from args import get_parser
20
+
21
+ # Dataset configuration initialization
22
+ # parser = get_parser()
23
+ # args = parser.parse_args()
24
+
25
+
26
+ class ReferDataset(data.Dataset):
27
+
28
+ def __init__(self,
29
+ args,
30
+ image_transforms=None,
31
+ target_transforms=None,
32
+ split='train',
33
+ eval_mode=False):
34
+
35
+ self.classes = []
36
+ self.image_transforms = image_transforms
37
+ self.target_transform = target_transforms
38
+ self.split = split
39
+ self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy)
40
+
41
+ self.max_tokens = 20
42
+
43
+ ref_ids = self.refer.getRefIds(split=self.split)
44
+ img_ids = self.refer.getImgIds(ref_ids)
45
+
46
+ all_imgs = self.refer.Imgs
47
+ self.imgs = list(all_imgs[i] for i in img_ids)
48
+ self.ref_ids = ref_ids
49
+
50
+ self.input_ids = []
51
+ self.attention_masks = []
52
+ self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)
53
+
54
+ # for metric learning
55
+ self.ROOT = '/data2/projects/seunghoon/VerbRIS/VerbCentric_CY/datasets/VRIS'
56
+ self.metric_learning = args.metric_learning
57
+ self.exclude_multiobj = args.exclude_multiobj
58
+ self.metric_mode = args.metric_mode
59
+ self.exclude_position = False
60
+ self.hp_selection = args.hp_selection
61
+
62
+ if self.metric_learning and eval_mode == False:
63
+ self.hardneg_prob = args.hn_prob
64
+ self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
65
+ self.hardpos_meta, self.hardneg_meta = self._load_metadata()
66
+ else:
67
+ self.hardneg_prob = 0.0
68
+ self.multi_obj_ref_ids = None
69
+ self.hardpos_meta, self.hardneg_meta = None, None
70
+
71
+
72
+ self.eval_mode = eval_mode
73
+ # if we are testing on a dataset, test all sentences of an object;
74
+ # o/w, we are validating during training, randomly sample one sentence for efficiency
75
+ for r in ref_ids:
76
+ ref = self.refer.Refs[r]
77
+
78
+ sentences_for_ref = []
79
+ attentions_for_ref = []
80
+
81
+ for i, (el, sent_id) in enumerate(zip(ref['sentences'], ref['sent_ids'])):
82
+ sentence_raw = el['raw']
83
+ attention_mask = [0] * self.max_tokens
84
+ padded_input_ids = [0] * self.max_tokens
85
+
86
+ input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True)
87
+
88
+ # truncation of tokens
89
+ input_ids = input_ids[:self.max_tokens]
90
+
91
+ padded_input_ids[:len(input_ids)] = input_ids
92
+ attention_mask[:len(input_ids)] = [1]*len(input_ids)
93
+
94
+ sentences_for_ref.append(torch.tensor(padded_input_ids).unsqueeze(0))
95
+ attentions_for_ref.append(torch.tensor(attention_mask).unsqueeze(0))
96
+
97
+ self.input_ids.append(sentences_for_ref)
98
+ self.attention_masks.append(attentions_for_ref)
99
+
100
+
101
+ def _tokenize(self, sentence):
102
+ attention_mask = [0] * self.max_tokens
103
+ padded_input_ids = [0] * self.max_tokens
104
+
105
+ input_ids = self.tokenizer.encode(text=sentence, add_special_tokens=True)
106
+ # truncation of tokens
107
+ input_ids = input_ids[:self.max_tokens]
108
+ padded_input_ids[:len(input_ids)] = input_ids
109
+ attention_mask[:len(input_ids)] = [1]*len(input_ids)
110
+
111
+ # match shape as (1, max_tokens)
112
+ return torch.tensor(padded_input_ids).unsqueeze(0), torch.tensor(attention_mask).unsqueeze(0)
113
+
114
+
115
+ def _load_multi_obj_ref_ids(self):
116
+ # Load multi-object reference IDs based on configurations
117
+ if not self.exclude_multiobj and not self.exclude_position :
118
+ return None
119
+ elif self.exclude_position:
120
+ multiobj_path = os.path.join(self.ROOT, 'multiobj_ov2_nopos.txt')
121
+ elif self.exclude_multiobj :
122
+ multiobj_path = os.path.join(self.ROOT, 'multiobj_ov3.txt')
123
+ with open(multiobj_path, 'r') as f:
124
+ return [int(line.strip()) for line in f.readlines()]
125
+
126
+ def _load_metadata(self):
127
+ # Load metadata for hard positive verb phrases, hard negative queries
128
+ if 'refined' in self.metric_mode or 'hardneg' in self.metric_mode :
129
+ hardpos_path = os.path.join(self.ROOT, 'hardpos_verdict_gref_v4.json')
130
+ else :
131
+ hardpos_path = os.path.join(self.ROOT, 'hardpos_verbphrase_0906upd.json')
132
+ # do not use hardneg_path
133
+ hardneg_path = os.path.join(self.ROOT, 'hardneg_verb.json')
134
+
135
+ with open(hardpos_path, 'r', encoding='utf-8') as f:
136
+ hardpos_json = json.load(f)
137
+ if "hardpos_only" in self.metric_mode :
138
+ hardneg_json = None
139
+ else :
140
+ with open(hardneg_path, 'r', encoding='utf-8') as q:
141
+ hardneg_json = json.load(q)
142
+ return hardpos_json, hardneg_json
143
+
144
+
145
+ def _get_hardpos_verb(self, ref, seg_id, sent_idx) :
146
+ if seg_id in self.multi_obj_ref_ids:
147
+ return ''
148
+
149
+ # Extract metadata for hard positives if present
150
+ hardpos_dict = self.hardpos_meta.get(str(seg_id), {})
151
+ if self.hp_selection == 'strict' :
152
+ sent_id_list = list(hardpos_dict.keys())
153
+ cur_hardpos = hardpos_dict.get(sent_id_list[sent_idx], {}).get('phrases', [])
154
+ else :
155
+ cur_hardpos = list(itertools.chain.from_iterable(hardpos_dict[sid]['phrases'] for sid in hardpos_dict))
156
+
157
+ if cur_hardpos:
158
+ # Assign a hard positive verb phrase if available
159
+ raw_verb = random.choice(cur_hardpos)
160
+ return raw_verb
161
+
162
+ return ''
163
+
164
+
165
+ def get_classes(self):
166
+ return self.classes
167
+
168
+ def __len__(self):
169
+ return len(self.ref_ids)
170
+
171
+ def __getitem__(self, index):
172
+ this_ref_id = self.ref_ids[index]
173
+ this_img_id = self.refer.getImgIds(this_ref_id)
174
+ this_img = self.refer.Imgs[this_img_id[0]]
175
+
176
+ IMAGE_DIR = '/data2/dataset/COCO2014/trainval2014/'
177
+ img = Image.open(os.path.join(IMAGE_DIR, this_img['file_name'])).convert("RGB")
178
+ ref = self.refer.loadRefs(this_ref_id)
179
+
180
+ ref_mask = np.array(self.refer.getMask(ref[0])['mask'])
181
+ annot = np.zeros(ref_mask.shape)
182
+ annot[ref_mask == 1] = 1
183
+
184
+ annot = Image.fromarray(annot.astype(np.uint8), mode="P")
185
+
186
+ if self.image_transforms is not None:
187
+ # resize, from PIL to tensor, and mean and std normalization
188
+ img, target = self.image_transforms(img, annot)
189
+
190
+ if self.eval_mode:
191
+ embedding = []
192
+ att = []
193
+ for s in range(len(self.input_ids[index])):
194
+ e = self.input_ids[index][s]
195
+ a = self.attention_masks[index][s]
196
+ embedding.append(e.unsqueeze(-1))
197
+ att.append(a.unsqueeze(-1))
198
+
199
+ tensor_embeddings = torch.cat(embedding, dim=-1)
200
+ attention_mask = torch.cat(att, dim=-1)
201
+
202
+ return img, target, tensor_embeddings, attention_mask
203
+
204
+ else: # train phase
205
+ choice_sent = np.random.choice(len(self.input_ids[index]))
206
+ tensor_embeddings = self.input_ids[index][choice_sent]
207
+ attention_mask = self.attention_masks[index][choice_sent]
208
+
209
+ if self.metric_learning:
210
+ pos_sent = torch.zeros_like(tensor_embeddings)
211
+ pos_attn_mask = torch.zeros_like(attention_mask)
212
+
213
+ if 'hardpos_' in self.metric_mode or self.hardneg_prob == 0.0:
214
+ if 'refined' in self.metric_mode :
215
+ pos_sent_picked = self._get_hardpos_verb(ref, this_ref_id, choice_sent)
216
+ else :
217
+ pos_sents = self.hardpos_meta[str(this_ref_id)].values()
218
+ # drop elements with none
219
+ pos_sents = [s for s in pos_sents if s is not None]
220
+ pos_sent_picked = random.choice(list(pos_sents))
221
+ if pos_sent_picked:
222
+ pos_sent, pos_attn_mask = self._tokenize(pos_sent_picked)
223
+
224
+ return img, target, tensor_embeddings, attention_mask, pos_sent, pos_attn_mask
225
+ else:
226
+ neg_sent = torch.zeros_like(tensor_embeddings)
227
+ neg_attn_mask = torch.zeros_like(attention_mask)
228
+
229
+ pos_sents = self.hardpos_meta[str(this_ref_id)].values()
230
+ # drop elements with none
231
+ pos_sents = [s for s in pos_sents if s is not None]
232
+ pos_sent_picked = random.choice(list(pos_sents))
233
+
234
+ if pos_sent_picked:
235
+ pos_sent, pos_attn_mask = self._tokenize(pos_sent_picked)
236
+
237
+ if random.random() < self.hardneg_prob:
238
+ neg_sents = self.hardneg_meta[str(this_ref_id)].values()
239
+ neg_sents = [s for s in neg_sents if s is not None]
240
+ neg_sent_picked = random.choice(list(neg_sents))
241
+ #print("neg_sent: ", neg_sent)
242
+
243
+ if neg_sent_picked:
244
+ neg_sent, neg_attn_mask = self._tokenize(neg_sent_picked)
245
+
246
+ return img, target, tensor_embeddings, attention_mask, pos_sent, pos_attn_mask, neg_sent, neg_attn_mask
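The pos_sent / neg_sent pairs returned above feed a metric-learning objective on sentence embeddings; the repository's actual loss lives in the training code. The following is only an illustrative triplet-style sketch, assuming (a) the text tensors have been squeezed to shape (batch, max_tokens) and (b) bert_model is the vendored BertModel, whose forward returns the last hidden state first. The mean pooling and the margin value are assumptions, not the repo's method.

import torch
import torch.nn.functional as F

def sentence_embed(bert_model, ids, mask):
    # Mean-pool the last hidden state over valid (non-padding) tokens.
    hidden = bert_model(ids, attention_mask=mask)[0]          # (B, T, C)
    m = mask.unsqueeze(-1).float()
    return (hidden * m).sum(1) / m.sum(1).clamp(min=1)

def triplet_metric_loss(bert_model, anc, anc_m, pos, pos_m, neg, neg_m, margin=0.5):
    # Skip samples whose pos/neg slots are the all-zero placeholders from __getitem__.
    valid = (pos_m.sum(-1) > 0) & (neg_m.sum(-1) > 0)
    if not valid.any():
        return torch.zeros((), device=anc.device)
    a = F.normalize(sentence_embed(bert_model, anc[valid], anc_m[valid]), dim=-1)
    p = F.normalize(sentence_embed(bert_model, pos[valid], pos_m[valid]), dim=-1)
    n = F.normalize(sentence_embed(bert_model, neg[valid], neg_m[valid]), dim=-1)
    # Cosine-similarity triplet loss: pull the hard positive in, push the hard negative away.
    return F.relu(margin - (a * p).sum(-1) + (a * n).sum(-1)).mean()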
LAVT-RIS/data/dataset_refer_zom.py ADDED
@@ -0,0 +1,296 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import torch.utils.data as data
5
+ import torch
6
+ import itertools
7
+ import numpy as np
8
+ from PIL import Image
9
+ import pdb
10
+ import copy
11
+ from random import choice
12
+ from bert.tokenization_bert import BertTokenizer
13
+
14
+ from refer.refer_zom import ZREFER
15
+ import copy
16
+ import random
17
+ import torch
18
+ from collections import defaultdict
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+ from torch.utils.data.distributed import DistributedSampler
23
+
24
+ from args import get_parser
25
+ import random
26
+ # Dataset configuration initialization
27
+ parser = get_parser()
28
+ args = parser.parse_args()
29
+
30
+
31
+ class Referzom_Dataset(data.Dataset):
32
+
33
+ def __init__(self,
34
+ args,
35
+ image_transforms=None,
36
+ target_transforms=None,
37
+ split='train',
38
+ eval_mode=False):
39
+
40
+ self.classes = []
41
+ self.image_transforms = image_transforms
42
+ self.target_transform = target_transforms
43
+ self.split = split
44
+ self.refer = ZREFER(args.refer_data_root, args.dataset, args.splitBy)
45
+ self.dataset_type = args.dataset
46
+ self.max_tokens = 20
47
+ ref_ids = self.refer.getRefIds(split=self.split)
48
+ self.img_ids = self.refer.getImgIds(ref_ids)
49
+
50
+ all_imgs = self.refer.Imgs
51
+ self.imgs = list(all_imgs[i] for i in self.img_ids)
52
+ self.ref_ids = ref_ids
53
+
54
+ self.input_ids = []
55
+ self.attention_masks = []
56
+ self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)
57
+
58
+ self.ROOT = '/data2/dataset/RefCOCO/VRIS'
59
+ self.metric_learning = args.metric_learning
60
+ self.exclude_multiobj = args.exclude_multiobj
61
+ self.metric_mode = args.metric_mode
62
+ self.exclude_position = False
+ self.hp_selection = getattr(args, 'hp_selection', 'loose') # read by _get_hardpos_verb(); default value assumed
63
+
64
+ if self.metric_learning and eval_mode == False:
65
+ self.hardneg_prob = args.hn_prob
66
+ self.multi_obj_ref_ids = self._load_multi_obj_ref_ids()
67
+ self.hardpos_meta, self.hardneg_meta = self._load_metadata()
68
+ else:
69
+ self.hardneg_prob = 0.0
70
+ self.multi_obj_ref_ids = None
71
+ self.hardpos_meta, self.hardneg_meta = None, None
72
+
73
+ self.eval_mode = eval_mode
74
+
75
+ self.zero_sent_id_list = []
76
+ self.one_sent_id_list = []
77
+ self.all_sent_id_list = []
78
+ self.sent_2_refid = {}
79
+
80
+
81
+ for r in ref_ids:
82
+ ref = self.refer.loadRefs(r)
83
+ source_type = ref[0]['source']
84
+
85
+ for sent_dict in ref[0]['sentences']:
86
+ sent_id = sent_dict['sent_id']
87
+
88
+ self.sent_2_refid[sent_id] = r
89
+ self.all_sent_id_list.append(sent_id)
90
+ if source_type=='zero':
91
+ self.zero_sent_id_list.append(sent_id)
92
+ else:
93
+ self.one_sent_id_list.append(sent_id)
94
+
95
+ for r in ref_ids:
96
+ ref = self.refer.Refs[r]
97
+
98
+ sentences_for_ref = []
99
+ attentions_for_ref = []
100
+
101
+ for i, el in enumerate(ref['sentences']):
102
+ sentence_raw = el['raw']
103
+ attention_mask = [0] * self.max_tokens
104
+ padded_input_ids = [0] * self.max_tokens
105
+
106
+ input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True)
107
+
108
+ # truncation of tokens
109
+ input_ids = input_ids[:self.max_tokens]
110
+
111
+ padded_input_ids[:len(input_ids)] = input_ids
112
+ attention_mask[:len(input_ids)] = [1]*len(input_ids)
113
+
114
+ sentences_for_ref.append(torch.tensor(padded_input_ids).unsqueeze(0))
115
+ attentions_for_ref.append(torch.tensor(attention_mask).unsqueeze(0))
116
+
117
+ self.input_ids.extend(sentences_for_ref)
118
+ self.attention_masks.extend(attentions_for_ref)
119
+
120
+
121
+ def get_classes(self):
122
+ return self.classes
123
+
124
+
125
+ def _tokenize(self, sentence):
126
+ attention_mask = [0] * self.max_tokens
127
+ padded_input_ids = [0] * self.max_tokens
128
+
129
+ input_ids = self.tokenizer.encode(text=sentence, add_special_tokens=True)
130
+ # truncation of tokens
131
+ input_ids = input_ids[:self.max_tokens]
132
+ padded_input_ids[:len(input_ids)] = input_ids
133
+ attention_mask[:len(input_ids)] = [1]*len(input_ids)
134
+
135
+ # match shape as (1, max_tokens)
136
+ return torch.tensor(padded_input_ids).unsqueeze(0), torch.tensor(attention_mask).unsqueeze(0)
137
+
138
+ def _load_multi_obj_ref_ids(self):
139
+ # Load multi-object reference IDs based on configurations
140
+ if not self.exclude_multiobj and not self.exclude_position :
141
+ return None
142
+ elif self.exclude_position:
143
+ multiobj_path = os.path.join(self.ROOT, 'multiobj_ov2_nopos.txt')
144
+ elif self.exclude_multiobj :
145
+ multiobj_path = os.path.join(self.ROOT, 'multiobj_ov3.txt')
146
+ with open(multiobj_path, 'r') as f:
147
+ return [int(line.strip()) for line in f.readlines()]
148
+
149
+ def _load_metadata(self):
150
+ hardpos_path = os.path.join(self.ROOT, 'verb_ext_text_example_refzom.json')
151
+ with open(hardpos_path, 'r', encoding='utf-8') as f:
152
+ hardpos_json = json.load(f)
153
+ if "hardpos_only" in self.metric_mode :
154
+ hardneg_json = None
155
+ # else :
156
+ # hardneg_path = os.path.join(self.ROOT, 'hardneg_verb.json')
157
+ # with open(hardneg_path, 'r', encoding='utf-8') as q:
158
+ # hardneg_json = json.load(q)
159
+ return hardpos_json, hardneg_json
160
+
161
+
162
+ def _get_hardpos_verb(self, ref, seg_id, sent_idx) :
163
+ if seg_id in self.multi_obj_ref_ids:
164
+ return ''
165
+
166
+ # Extract metadata for hard positives if present
167
+ hardpos_dict = self.hardpos_meta.get(str(seg_id), {})
168
+ if self.hp_selection == 'strict' :
169
+ sent_id_list = list(hardpos_dict.keys())
170
+ cur_hardpos = hardpos_dict.get(sent_id_list[sent_idx], {}).get('phrases', [])
171
+ else :
172
+ cur_hardpos = list(itertools.chain.from_iterable(hardpos_dict[sid]['phrases'] for sid in hardpos_dict))
173
+
174
+ if cur_hardpos:
175
+ # Assign a hard positive verb phrase if available
176
+ raw_verb = random.choice(cur_hardpos)
177
+ return raw_verb
178
+
179
+ return ''
180
+
181
+ def __len__(self):
182
+ return len(self.all_sent_id_list)
183
+
184
+ def __getitem__(self, index):
185
+
186
+ sent_id = self.all_sent_id_list[index]
187
+ this_ref_id = self.sent_2_refid[sent_id]
188
+
189
+ this_img_id = self.refer.getImgIds(this_ref_id)
190
+ this_img = self.refer.Imgs[this_img_id[0]]
191
+
192
+ IMAGE_DIR = '/data2/dataset/COCO2014/trainval2014/'
193
+ img = Image.open(os.path.join(IMAGE_DIR, this_img['file_name'])).convert("RGB")
194
+
195
+ ref = self.refer.loadRefs(this_ref_id)
196
+ if self.dataset_type == 'ref-zom':
197
+ source_type = ref[0]['source']
198
+ else:
199
+ source_type = 'not_zero'
200
+
201
+ ref_mask = np.array(self.refer.getMask(ref[0])['mask'])
202
+ annot = np.zeros(ref_mask.shape)
203
+ annot[ref_mask == 1] = 1
204
+ annot = Image.fromarray(annot.astype(np.uint8), mode="P")
205
+
206
+
207
+ if self.image_transforms is not None:
208
+ img, target = self.image_transforms(img, annot)
209
+
210
+ if self.eval_mode:
211
+ embedding = []
212
+ att = []
213
+ for s in range(len(self.input_ids[index])):
214
+ padded_input_ids = self.input_ids[index][s]
215
+ attention_mask = self.attention_masks[index][s]
216
+
217
+ embedding.append(padded_input_ids.unsqueeze(-1))
218
+ att.append(attention_mask.unsqueeze(-1))
219
+
220
+ tensor_embeddings = torch.cat(embedding, dim=-1)
221
+ attention_mask = torch.cat(att, dim=-1)
222
+ return img, target, source_type, tensor_embeddings, attention_mask
223
+
224
+ else:
225
+ choice_sent = np.random.choice(len(self.input_ids[index]))
226
+ tensor_embeddings = self.input_ids[index][choice_sent]
227
+ attention_mask = self.attention_masks[index][choice_sent]
228
+
229
+ if self.metric_learning :
230
+ pos_sent = torch.zeros_like(tensor_embeddings)
231
+ pos_attn_mask = torch.zeros_like(attention_mask)
232
+
233
+ ## use hard positives only: metric_mode contains 'hardpos_' or hard negatives are disabled (hardneg_prob == 0.0)
234
+ if 'hardpos_' in self.metric_mode or self.hardneg_prob == 0.0:
235
+ pos_type = 'zero'
236
+ if 'refined' in self.metric_mode :
237
+ pos_sent_picked = self._get_hardpos_verb(ref, this_ref_id, choice_sent)
238
+ else :
239
+ pos_sents = self.hardpos_meta[str(this_ref_id)].values()
240
+ # drop None entries
241
+ pos_sents = [s for s in pos_sents if s is not None]
242
+ pos_sent_picked = random.choice(list(pos_sents))
243
+ if pos_sent_picked :
244
+ pos_type = 'hardpos'
245
+ pos_sent, pos_attn_mask = self._tokenize(pos_sent_picked)
246
+ pos_sent = pos_sent.squeeze(0) if pos_sent.dim() == 2 and pos_sent.size(0) == 1 else pos_sent
247
+ pos_attn_mask = pos_attn_mask.squeeze(0) if pos_attn_mask.size(0) == 1 else pos_attn_mask
248
+
249
+ return img, target, source_type, tensor_embeddings, attention_mask, pos_sent, pos_attn_mask, pos_type
250
+
251
+ return img, target, source_type, tensor_embeddings, attention_mask
252
+
253
+
254
+
255
+
256
+ class Refzom_DistributedSampler(DistributedSampler):
257
+ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
258
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
259
+ self.one_id_list = dataset.one_sent_id_list
260
+
261
+ self.zero_id_list = dataset.zero_sent_id_list
262
+ self.sent_ids_list = dataset.all_sent_id_list
263
+ if self.shuffle:
264
+ random.shuffle(self.one_id_list)
265
+ random.shuffle(self.zero_id_list)
266
+
267
+ self.sent_id = self.insert_evenly(self.zero_id_list,self.one_id_list)
268
+ self.indices = self.get_positions(self.sent_ids_list, self.sent_id)
269
+
270
+ def get_positions(self, list_a, list_b):
271
+ position_dict = {value: index for index, value in enumerate(list_a)}
272
+ positions = [position_dict[item] for item in list_b]
273
+
274
+ return positions
275
+
276
+ def insert_evenly(self, list_a, list_b):
277
+ len_a = len(list_a)
278
+ len_b = len(list_b)
279
+ block_size = len_b // len_a
280
+
281
+ result = []
282
+ for i in range(len_a):
283
+ start = i * block_size
284
+ end = (i + 1) * block_size
285
+ result.extend(list_b[start:end])
286
+ result.append(list_a[i])
287
+
288
+ remaining = list_b[(len_a * block_size):]
289
+ result.extend(remaining)
290
+
291
+ return result
292
+
293
+ def __iter__(self):
294
+
295
+ indices_per_process = self.indices[self.rank::self.num_replicas]
296
+ return iter(indices_per_process)
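
As a sanity check on the interleaving above, the following minimal sketch (the toy id values are hypothetical, not taken from the dataset) reproduces the insert_evenly logic on small lists and shows how zero-source ids are spread evenly among one-source ids:

# minimal standalone sketch of the insert_evenly interleaving (assumed toy inputs)
def insert_evenly(list_a, list_b):
    block_size = len(list_b) // len(list_a)   # one-source items per zero-source item
    result = []
    for i, a in enumerate(list_a):
        result.extend(list_b[i * block_size:(i + 1) * block_size])
        result.append(a)                      # place one zero-source id after each block
    result.extend(list_b[len(list_a) * block_size:])  # leftover one-source ids
    return result

zero_ids = [100, 101]            # hypothetical zero-source sentence ids
one_ids = [1, 2, 3, 4, 5, 6]     # hypothetical one-source sentence ids
print(insert_evenly(zero_ids, one_ids))   # -> [1, 2, 3, 100, 4, 5, 6, 101]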
LAVT-RIS/datagen.txt ADDED
@@ -0,0 +1,49 @@
1
+ [I 2024-12-15 13:30:17.641 ServerApp] jupyter_lsp | extension was successfully linked.
2
+ [I 2024-12-15 13:30:17.652 ServerApp] jupyter_server_terminals | extension was successfully linked.
3
+ [I 2024-12-15 13:30:17.663 ServerApp] jupyterlab | extension was successfully linked.
4
+ [W 2024-12-15 13:30:17.687 JupyterNotebookApp] 'password' has moved from NotebookApp to ServerApp. This config will be passed to ServerApp. Be sure to update your config before our next release.
5
+ [W 2024-12-15 13:30:17.690 ServerApp] ServerApp.password config is deprecated in 2.0. Use PasswordIdentityProvider.hashed_password.
6
+ [I 2024-12-15 13:30:17.691 ServerApp] notebook | extension was successfully linked.
7
+ [I 2024-12-15 13:30:21.701 ServerApp] notebook_shim | extension was successfully linked.
8
+ [I 2024-12-15 13:30:21.734 ServerApp] notebook_shim | extension was successfully loaded.
9
+ [I 2024-12-15 13:30:21.737 ServerApp] jupyter_lsp | extension was successfully loaded.
10
+ [I 2024-12-15 13:30:21.738 ServerApp] jupyter_server_terminals | extension was successfully loaded.
11
+ [I 2024-12-15 13:30:21.745 LabApp] JupyterLab extension loaded from /home/chaeyun/.conda/envs/cris/lib/python3.9/site-packages/jupyterlab
12
+ [I 2024-12-15 13:30:21.745 LabApp] JupyterLab application directory is /data/conda_envs/chaeyun/envs/cris/share/jupyter/lab
13
+ [I 2024-12-15 13:30:21.746 LabApp] Extension Manager is 'pypi'.
14
+ [I 2024-12-15 13:30:21.846 ServerApp] jupyterlab | extension was successfully loaded.
15
+ [I 2024-12-15 13:30:21.856 ServerApp] notebook | extension was successfully loaded.
16
+ [I 2024-12-15 13:30:21.859 ServerApp] Serving notebooks from local directory: /data2/projects/chaeyun/LAVT-RIS
17
+ [I 2024-12-15 13:30:21.859 ServerApp] Jupyter Server 2.14.2 is running at:
18
+ [I 2024-12-15 13:30:21.859 ServerApp] http://localhost:9821/tree
19
+ [I 2024-12-15 13:30:21.859 ServerApp] http://127.0.0.1:9821/tree
20
+ [I 2024-12-15 13:30:21.859 ServerApp] Use Control-C to stop this server and shut down all kernels (twice to skip confirmation).
21
+ [I 2024-12-15 13:30:23.433 ServerApp] Skipped non-installed server(s): bash-language-server, dockerfile-language-server-nodejs, javascript-typescript-langserver, jedi-language-server, julia-language-server, pyright, python-language-server, python-lsp-server, r-languageserver, sql-language-server, texlab, typescript-language-server, unified-language-server, vscode-css-languageserver-bin, vscode-html-languageserver-bin, vscode-json-languageserver-bin, yaml-language-server
22
+ [W 2024-12-15 13:30:53.210 ServerApp] 404 GET /t/hub/api (@::1) 546.55ms referer=None
23
+ [W 2024-12-15 13:30:53.218 ServerApp] 404 GET /t/tree? (@::1) 1.17ms referer=None
24
+ [W 2024-12-15 13:30:55.469 ServerApp] 404 GET /t/api/kernelspecs?1734237055439 (@::1) 12.74ms referer=None
25
+ [W 2024-12-15 13:30:55.470 ServerApp] 404 GET /t/api/kernels?1734237055443 (@::1) 13.42ms referer=None
26
+ [W 2024-12-15 13:30:55.471 ServerApp] 404 GET /t/api/kernels?1734237055444 (@::1) 14.00ms referer=None
27
+ [W 2024-12-15 13:30:55.472 ServerApp] 404 GET /t/api/sessions?1734237055445 (@::1) 0.90ms referer=None
28
+ [W 2024-12-15 13:30:55.479 ServerApp] 404 GET /t/api/kernelspecs?1734237055472 (@::1) 1.03ms referer=None
29
+ [W 2024-12-15 13:30:56.573 ServerApp] 404 GET /hub/api (@::1) 17.58ms referer=None
30
+ [I 2024-12-15 13:30:56.578 JupyterNotebookApp] 302 GET /tree? (@::1) 0.64ms
31
+ [I 2024-12-15 13:30:58.280 ServerApp] User 7d210c8c6a6c435f8810d3e47520e8aa logged in.
32
+ [I 2024-12-15 13:30:58.281 ServerApp] 302 POST /login? (7d210c8c6a6c435f8810d3e47520e8aa@::1) 1.42ms
33
+ [I 2024-12-15 13:31:14.088 ServerApp] Creating new notebook in
34
+ [I 2024-12-15 13:31:15.719 ServerApp] Kernel started: dce7adc9-1c36-40aa-b127-027de8c4cc1d
35
+ [W 2024-12-15 13:31:15.740 ServerApp] delete /angle_vis-jvsc-275d6f36-b360-4853-b869-423f53ae33e58f8ea480-b4de-4dbf-8451-c534a31a02d6.ipynb
36
+ [I 2024-12-15 13:31:44.218 ServerApp] Connecting to kernel dce7adc9-1c36-40aa-b127-027de8c4cc1d.
37
+ [W 2024-12-15 13:31:48.296 ServerApp] 404 GET /nbextensions/jupyter-js-widgets/extension.js (@::1) 43.15ms referer=None
38
+ [W 2024-12-15 13:31:48.297 ServerApp] 404 GET /nbextensions/viewer/extension.js (@::1) 43.91ms referer=None
39
+ [I 2024-12-15 13:36:00.460 ServerApp] Starting buffering for dce7adc9-1c36-40aa-b127-027de8c4cc1d:a783cc47-773e-4781-be41-f059323b1741
40
+ [I 2024-12-15 13:36:14.286 ServerApp] Creating new notebook in
41
+ [I 2024-12-15 13:36:15.676 ServerApp] Kernel started: a6461fc1-58b7-4ec4-9a62-e5e245f3fb94
42
+ [W 2024-12-15 13:36:15.699 ServerApp] delete /angle_vis-jvsc-3fcaa7d2-467e-4c6a-982c-7ecc8685d6df4ba96f03-8186-48f0-b8c0-5c6555991454.ipynb
43
+ [I 2024-12-15 13:36:34.968 ServerApp] Connecting to kernel a6461fc1-58b7-4ec4-9a62-e5e245f3fb94.
44
+ [W 2024-12-15 13:36:38.417 ServerApp] 404 GET /nbextensions/jupyter-js-widgets/extension.js (@::1) 348.55ms referer=None
45
+ [W 2024-12-15 13:36:38.418 ServerApp] 404 GET /nbextensions/viewer/extension.js (@::1) 349.85ms referer=None
46
+ srun: Job step aborted: Waiting up to 32 seconds for job step to finish.
47
+ slurmstepd-node03: error: *** STEP 23990.0 ON node03 CANCELLED AT 2024-12-15T14:08:56 ***
48
+ slurmstepd-node03: error: *** JOB 23990 ON node03 CANCELLED AT 2024-12-15T14:08:56 ***
49
+ [C 2024-12-15 14:08:56.976 ServerApp] received signal 15, stopping
LAVT-RIS/demo_inference.py ADDED
@@ -0,0 +1,118 @@
1
+ image_path = './demo/demo.jpg'
2
+ sentence = 'the most handsome guy'
3
+ weights = './checkpoints/refcoco.pth'
4
+ device = 'cuda:0'
5
+
6
+ # pre-process the input image
7
+ from PIL import Image
8
+ import torchvision.transforms as T
9
+ import numpy as np
10
+ img = Image.open(image_path).convert("RGB")
11
+ img_ndarray = np.array(img) # (orig_h, orig_w, 3); for visualization
12
+ original_w, original_h = img.size # PIL .size returns width first and height second
13
+
14
+ image_transforms = T.Compose(
15
+ [
16
+ T.Resize(480),
17
+ T.ToTensor(),
18
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
19
+ ]
20
+ )
21
+
22
+ img = image_transforms(img).unsqueeze(0) # (1, 3, 480, 480)
23
+ img = img.to(device) # for inference (input)
24
+
25
+ # pre-process the raw sentence
26
+ from bert.tokenization_bert import BertTokenizer
27
+ import torch
28
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
29
+ sentence_tokenized = tokenizer.encode(text=sentence, add_special_tokens=True)
30
+ sentence_tokenized = sentence_tokenized[:20] # if the sentence is longer than 20 tokens, truncate it to the first 20 tokens
31
+ # pad the tokenized sentence
32
+ padded_sent_toks = [0] * 20
33
+ padded_sent_toks[:len(sentence_tokenized)] = sentence_tokenized
34
+ # create a sentence token mask: 1 for real words; 0 for padded tokens
35
+ attention_mask = [0] * 20
36
+ attention_mask[:len(sentence_tokenized)] = [1]*len(sentence_tokenized)
37
+ # convert lists to tensors
38
+ padded_sent_toks = torch.tensor(padded_sent_toks).unsqueeze(0) # (1, 20)
39
+ attention_mask = torch.tensor(attention_mask).unsqueeze(0) # (1, 20)
40
+ padded_sent_toks = padded_sent_toks.to(device) # for inference (input)
41
+ attention_mask = attention_mask.to(device) # for inference (input)
42
+
43
+ # initialize model and load weights
44
+ from bert.modeling_bert import BertModel
45
+ from lib import segmentation
46
+
47
+ # construct a mini args class; like from a config file
48
+
49
+
50
+ class args:
51
+ swin_type = 'base'
52
+ window12 = True
53
+ mha = ''
54
+ fusion_drop = 0.0
55
+
56
+
57
+ single_model = segmentation.__dict__['lavt'](pretrained='', args=args)
58
+ single_model.to(device)
59
+ model_class = BertModel
60
+ single_bert_model = model_class.from_pretrained('bert-base-uncased')
61
+ single_bert_model.pooler = None
62
+
63
+ checkpoint = torch.load(weights, map_location='cpu')
64
+ single_bert_model.load_state_dict(checkpoint['bert_model'])
65
+ single_model.load_state_dict(checkpoint['model'])
66
+ model = single_model.to(device)
67
+ bert_model = single_bert_model.to(device)
68
+
69
+
70
+ # inference
71
+ import torch.nn.functional as F
72
+ last_hidden_states = bert_model(padded_sent_toks, attention_mask=attention_mask)[0]
73
+ embedding = last_hidden_states.permute(0, 2, 1)
74
+ output = model(img, embedding, l_mask=attention_mask.unsqueeze(-1))
75
+ output = output.argmax(1, keepdim=True) # (1, 1, 480, 480)
76
+ output = F.interpolate(output.float(), (original_h, original_w)) # 'nearest'; resize to the original image size
77
+ output = output.squeeze() # (orig_h, orig_w)
78
+ output = output.cpu().data.numpy() # (orig_h, orig_w)
79
+
80
+
81
+ # show/save results
82
+ def overlay_davis(image, mask, colors=[[0, 0, 0], [255, 0, 0]], cscale=1, alpha=0.4):
83
+ from scipy.ndimage import binary_dilation  # scipy.ndimage.morphology is deprecated in recent SciPy
84
+
85
+ colors = np.reshape(colors, (-1, 3))
86
+ colors = np.atleast_2d(colors) * cscale
87
+
88
+ im_overlay = image.copy()
89
+ object_ids = np.unique(mask)
90
+
91
+ for object_id in object_ids[1:]:
92
+ # Overlay color on binary mask
93
+ foreground = image*alpha + np.ones(image.shape)*(1-alpha) * np.array(colors[object_id])
94
+ binary_mask = mask == object_id
95
+
96
+ # Compose image
97
+ im_overlay[binary_mask] = foreground[binary_mask]
98
+
99
+ # countours = skimage.morphology.binary.binary_dilation(binary_mask) - binary_mask
100
+ countours = binary_dilation(binary_mask) ^ binary_mask
101
+ # countours = cv2.dilate(binary_mask, cv2.getStructuringElement(cv2.MORPH_CROSS,(3,3))) - binary_mask
102
+ im_overlay[countours, :] = 0
103
+
104
+ return im_overlay.astype(image.dtype)
105
+
106
+
107
+ output = output.astype(np.uint8) # (orig_h, orig_w), np.uint8
108
+ # Overlay the mask on the image
109
+ visualization = overlay_davis(img_ndarray, output) # red
110
+ visualization = Image.fromarray(visualization)
111
+ # show the visualization
112
+ #visualization.show()
113
+ # Save the visualization
114
+ visualization.save('./demo/demo_result.jpg')
115
+
116
+
117
+
118
+
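
The sentence pre-processing in the demo above mirrors the dataset's _tokenize logic; as a rough sketch only (the helper name tokenize_sentence and the 20-token budget are assumptions taken from this script, not an API of the repository), it can be factored into one reusable function:

import torch

def tokenize_sentence(tokenizer, sentence, max_tokens=20, device='cpu'):
    # encode with special tokens, truncate, zero-pad, and build the attention mask
    ids = tokenizer.encode(text=sentence, add_special_tokens=True)[:max_tokens]
    padded = [0] * max_tokens
    padded[:len(ids)] = ids
    mask = [0] * max_tokens
    mask[:len(ids)] = [1] * len(ids)          # 1 for real tokens, 0 for padding
    return (torch.tensor(padded).unsqueeze(0).to(device),   # (1, max_tokens)
            torch.tensor(mask).unsqueeze(0).to(device))     # (1, max_tokens)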
LAVT-RIS/donghwa/args.py ADDED
@@ -0,0 +1,212 @@
1
+ import argparse
2
+
3
+
4
+ def get_parser():
5
+ parser = argparse.ArgumentParser(description='LAVT training and testing')
6
+ parser.add_argument('--amsgrad', action='store_true',
7
+ help='if true, set amsgrad to True in an Adam or AdamW optimizer.')
8
+ parser.add_argument('-b', '--batch-size', default=8, type=int)
9
+ parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer')
10
+ parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights')
11
+ parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog')
12
+ parser.add_argument('--ddp_trained_weights', action='store_true',
13
+ help='only needs to be specified when testing; '
14
+ 'indicates that the weights to be loaded are from a DDP-trained model')
15
+ parser.add_argument('--device', default='cuda:0', help='device') # only used when testing on a single machine
16
+ parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run')
17
+ parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs')
18
+ parser.add_argument('--img_size', default=480, type=int, help='input image size')
19
+ parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel')
20
+ parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate')
21
+ parser.add_argument('--mha', default='', help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4,'
22
+ 'where a, b, c, and d refer to the numbers of heads in stage-1,'
23
+ 'stage-2, stage-3, and stage-4 PWAMs')
24
+ parser.add_argument('--model', default='lavt', help='model: lavt, lavt_one')
25
+ parser.add_argument('--model_id', default='lavt', help='name to identify the model')
26
+ parser.add_argument('--output-dir', default='./checkpoints/', help='path where to save checkpoint weights')
27
+ parser.add_argument('--pin_mem', action='store_true',
28
+ help='If true, pin memory when using the data loader.')
29
+ parser.add_argument('--pretrained_swin_weights', default='',
30
+ help='path to pre-trained Swin backbone weights')
31
+ parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
32
+ parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory')
33
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
34
+ parser.add_argument('--split', default='test', help='only used when testing')
35
+ parser.add_argument('--splitBy', default='unc', help='change to umd or google when the dataset is G-Ref (RefCOCOg)')
36
+ parser.add_argument('--swin_type', default='base',
37
+ help='tiny, small, base, or large variants of the Swin Transformer')
38
+ parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay',
39
+ dest='weight_decay')
40
+ parser.add_argument('--window12', action='store_true',
41
+ help='only needs to be specified when testing; initializes Swin with window size 12 '
42
+ 'instead of the default 7. When training, the window size is inferred from the '
43
+ 'pre-trained weights file name (containing \'window12\').')
44
+ parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers')
45
+ parser.add_argument('--config',
46
+ default='path to xxx.yaml',
47
+ type=str,
48
+ help='config file')
49
+ return parser
50
+
51
+ # -----------------------------------------------------------------------------
52
+ # Functions for parsing args
53
+ # -----------------------------------------------------------------------------
54
+ import copy
55
+ import os
56
+ from ast import literal_eval
57
+
58
+ import yaml
59
+
60
+
61
+ class CfgNode(dict):
62
+ """
63
+ CfgNode represents an internal node in the configuration tree. It's a simple
64
+ dict-like container that allows for attribute-based access to keys.
65
+ """
66
+ def __init__(self, init_dict=None, key_list=None, new_allowed=False):
67
+ # Recursively convert nested dictionaries in init_dict into CfgNodes
68
+ init_dict = {} if init_dict is None else init_dict
69
+ key_list = [] if key_list is None else key_list
70
+ for k, v in init_dict.items():
71
+ if type(v) is dict:
72
+ # Convert dict to CfgNode
73
+ init_dict[k] = CfgNode(v, key_list=key_list + [k])
74
+ super(CfgNode, self).__init__(init_dict)
75
+
76
+ def __getattr__(self, name):
77
+ if name in self:
78
+ return self[name]
79
+ else:
80
+ raise AttributeError(name)
81
+
82
+ def __setattr__(self, name, value):
83
+ self[name] = value
84
+
85
+ def __str__(self):
86
+ def _indent(s_, num_spaces):
87
+ s = s_.split("\n")
88
+ if len(s) == 1:
89
+ return s_
90
+ first = s.pop(0)
91
+ s = [(num_spaces * " ") + line for line in s]
92
+ s = "\n".join(s)
93
+ s = first + "\n" + s
94
+ return s
95
+
96
+ r = ""
97
+ s = []
98
+ for k, v in sorted(self.items()):
99
+ separator = "\n" if isinstance(v, CfgNode) else " "
100
+ attr_str = "{}:{}{}".format(str(k), separator, str(v))
101
+ attr_str = _indent(attr_str, 2)
102
+ s.append(attr_str)
103
+ r += "\n".join(s)
104
+ return r
105
+
106
+ def __repr__(self):
107
+ return "{}({})".format(self.__class__.__name__,
108
+ super(CfgNode, self).__repr__())
109
+
110
+
111
+ def load_cfg_from_cfg_file(file):
112
+ cfg = {}
113
+ assert os.path.isfile(file) and file.endswith('.yaml'), \
114
+ '{} is not a yaml file'.format(file)
115
+
116
+ with open(file, 'r') as f:
117
+ cfg_from_file = yaml.safe_load(f)
118
+
119
+ for key in cfg_from_file:
120
+ for k, v in cfg_from_file[key].items():
121
+ cfg[k] = v
122
+
123
+ cfg = CfgNode(cfg)
124
+ return cfg
125
+
126
+
127
+ def merge_cfg_from_list(cfg, cfg_list):
128
+ new_cfg = copy.deepcopy(cfg)
129
+ assert len(cfg_list) % 2 == 0
130
+ for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
131
+ subkey = full_key.split('.')[-1]
132
+ assert subkey in cfg, 'Non-existent key: {}'.format(full_key)
133
+ value = _decode_cfg_value(v)
134
+ value = _check_and_coerce_cfg_value_type(value, cfg[subkey], subkey,
135
+ full_key)
136
+ setattr(new_cfg, subkey, value)
137
+
138
+ return new_cfg
139
+
140
+
141
+ def _decode_cfg_value(v):
142
+ """Decodes a raw config value (e.g., from a yaml config file or command
143
+ line argument) into a Python object.
144
+ """
145
+ # All remaining processing is only applied to strings
146
+ if not isinstance(v, str):
147
+ return v
148
+ # Try to interpret `v` as a:
149
+ # string, number, tuple, list, dict, boolean, or None
150
+ try:
151
+ v = literal_eval(v)
152
+ # The following two excepts allow v to pass through when it represents a
153
+ # string.
154
+ #
155
+ # Longer explanation:
156
+ # The type of v is always a string (before calling literal_eval), but
157
+ # sometimes it *represents* a string and other times a data structure, like
158
+ # a list. In the case that v represents a string, what we got back from the
159
+ # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is
160
+ # ok with '"foo"', but will raise a ValueError if given 'foo'. In other
161
+ # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval
162
+ # will raise a SyntaxError.
163
+ except ValueError:
164
+ pass
165
+ except SyntaxError:
166
+ pass
167
+ return v
168
+
169
+
170
+ def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
171
+ """Checks that `replacement`, which is intended to replace `original` is of
172
+ the right type. The type is correct if it matches exactly or is one of a few
173
+ cases in which the type can be easily coerced.
174
+ """
175
+ original_type = type(original)
176
+ replacement_type = type(replacement)
177
+
178
+ # The types must match (with some exceptions)
179
+ if replacement_type == original_type:
180
+ return replacement
181
+
182
+ # Cast replacement from from_type to to_type if the replacement and original
183
+ # types match from_type and to_type
184
+ def conditional_cast(from_type, to_type):
185
+ if replacement_type == from_type and original_type == to_type:
186
+ return True, to_type(replacement)
187
+ else:
188
+ return False, None
189
+
190
+ # Conditionally casts
191
+ # list <-> tuple
192
+ casts = [(tuple, list), (list, tuple)]
193
+ # For py2: allow converting from str (bytes) to a unicode string
194
+ try:
195
+ casts.append((str, unicode)) # noqa: F821
196
+ except Exception:
197
+ pass
198
+
199
+ for (from_type, to_type) in casts:
200
+ converted, converted_value = conditional_cast(from_type, to_type)
201
+ if converted:
202
+ return converted_value
203
+
204
+ raise ValueError(
205
+ "Type mismatch ({} vs. {}) with values ({} vs. {}) for config "
206
+ "key: {}".format(original_type, replacement_type, original,
207
+ replacement, full_key))
208
+
209
+
210
+ if __name__ == "__main__":
211
+ parser = get_parser()
212
+ args_dict = parser.parse_args()
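
For reference, load_cfg_from_cfg_file expects a two-level YAML whose top-level section names are discarded and whose inner keys are flattened into one namespace, and merge_cfg_from_list matches only the last component of a dotted key. A usage sketch (the file name config.yaml and the section/key values below are hypothetical):

# hypothetical config.yaml:
#   DATA:
#     dataset: refcoco
#     img_size: 480
#   TRAIN:
#     lr: 0.00005
cfg = load_cfg_from_cfg_file('config.yaml')             # cfg.dataset == 'refcoco'
cfg = merge_cfg_from_list(cfg, ['TRAIN.lr', '0.0001'])  # only the 'lr' component is matched
print(cfg.lr, cfg.img_size)                             # 0.0001 480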
LAVT-RIS/donghwa/config/__pycache__/utils.cpython-37.pyc ADDED
Binary file (4.64 kB). View file
 
LAVT-RIS/donghwa/config/n_obj/n_12.yaml ADDED
@@ -0,0 +1 @@
1
+ n_obj_bin : n_12
LAVT-RIS/donghwa/config/n_obj/n_34.yaml ADDED
@@ -0,0 +1 @@
1
+ n_obj_bin : n_34
LAVT-RIS/donghwa/config/n_obj/n_56.yaml ADDED
@@ -0,0 +1 @@
1
+ n_obj_bin : n_56