liuyizhang committed
Commit 1ce5e18
1 parent: 77de6b0

add transformers_4_35_0

This view is limited to 50 files because it contains too many changes. See the raw diff for the full commit.
Files changed (50)
  1. .gitignore +2 -0
  2. kosmos_utils.py +7 -2
  3. transformers_4_35_0/__init__.py +0 -0
  4. transformers_4_35_0/activations.py +251 -0
  5. transformers_4_35_0/activations_tf.py +134 -0
  6. transformers_4_35_0/audio_utils.py +721 -0
  7. transformers_4_35_0/benchmark/__init__.py +0 -0
  8. transformers_4_35_0/benchmark/benchmark.py +271 -0
  9. transformers_4_35_0/benchmark/benchmark_args.py +114 -0
  10. transformers_4_35_0/benchmark/benchmark_args_tf.py +136 -0
  11. transformers_4_35_0/benchmark/benchmark_args_utils.py +166 -0
  12. transformers_4_35_0/benchmark/benchmark_tf.py +303 -0
  13. transformers_4_35_0/benchmark/benchmark_utils.py +914 -0
  14. transformers_4_35_0/commands/__init__.py +27 -0
  15. transformers_4_35_0/commands/add_new_model.py +259 -0
  16. transformers_4_35_0/commands/add_new_model_like.py +1763 -0
  17. transformers_4_35_0/commands/convert.py +184 -0
  18. transformers_4_35_0/commands/download.py +56 -0
  19. transformers_4_35_0/commands/env.py +143 -0
  20. transformers_4_35_0/commands/lfs.py +226 -0
  21. transformers_4_35_0/commands/pt_to_tf.py +425 -0
  22. transformers_4_35_0/commands/run.py +110 -0
  23. transformers_4_35_0/commands/serving.py +228 -0
  24. transformers_4_35_0/commands/train.py +158 -0
  25. transformers_4_35_0/commands/transformers_cli.py +59 -0
  26. transformers_4_35_0/commands/user.py +197 -0
  27. transformers_4_35_0/configuration_utils.py +1075 -0
  28. transformers_4_35_0/convert_graph_to_onnx.py +569 -0
  29. transformers_4_35_0/convert_pytorch_checkpoint_to_tf2.py +492 -0
  30. transformers_4_35_0/convert_slow_tokenizer.py +1318 -0
  31. transformers_4_35_0/convert_slow_tokenizers_checkpoints_to_fast.py +126 -0
  32. transformers_4_35_0/convert_tf_hub_seq_to_seq_bert_to_pytorch.py +88 -0
  33. transformers_4_35_0/data/__init__.py +44 -0
  34. transformers_4_35_0/data/data_collator.py +1535 -0
  35. transformers_4_35_0/data/datasets/__init__.py +23 -0
  36. transformers_4_35_0/data/datasets/glue.py +161 -0
  37. transformers_4_35_0/data/datasets/language_modeling.py +530 -0
  38. transformers_4_35_0/data/datasets/squad.py +229 -0
  39. transformers_4_35_0/data/metrics/__init__.py +98 -0
  40. transformers_4_35_0/data/metrics/squad_metrics.py +780 -0
  41. transformers_4_35_0/data/processors/__init__.py +18 -0
  42. transformers_4_35_0/data/processors/glue.py +643 -0
  43. transformers_4_35_0/data/processors/squad.py +845 -0
  44. transformers_4_35_0/data/processors/utils.py +349 -0
  45. transformers_4_35_0/data/processors/xnli.py +97 -0
  46. transformers_4_35_0/debug_utils.py +346 -0
  47. transformers_4_35_0/deepspeed.py +40 -0
  48. transformers_4_35_0/dependency_versions_check.py +63 -0
  49. transformers_4_35_0/dependency_versions_table.py +90 -0
  50. transformers_4_35_0/dynamic_module_utils.py +624 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ *.pyc
kosmos_utils.py CHANGED
@@ -1,11 +1,16 @@
 import random
 import numpy as np
- import os
+ import os,sys
 import requests
 import torch
 import torchvision.transforms as torchvision_T
 from PIL import Image
- from transformers import AutoProcessor, AutoModelForVision2Seq
+
+ # from transformers import AutoProcessor, AutoModelForVision2Seq
+ import subprocess, io, os, sys, time
+ sys.path.insert(0, './transformers_4_35_0')
+ from transformers_4_35_0 import AutoProcessor, AutoModelForVision2Seq
+
 import cv2
 import ast

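For context, the point of this change is to bypass the pip-installed `transformers` package and load Kosmos-2 through the vendored 4.35.0 copy instead. A minimal sketch of that usage pattern, assuming the full vendored package is present in the repo root and using a hypothetical Kosmos-2 checkpoint name (neither the checkpoint nor the device handling comes from this diff):

```python
import sys
import torch

# Mirror the diff: make the vendored transformers 4.35.0 importable
sys.path.insert(0, './transformers_4_35_0')
from transformers_4_35_0 import AutoProcessor, AutoModelForVision2Seq

# Hypothetical checkpoint name, for illustration only
ckpt = "microsoft/kosmos-2-patch14-224"

processor = AutoProcessor.from_pretrained(ckpt)
model = AutoModelForVision2Seq.from_pretrained(ckpt)
model.to("cuda" if torch.cuda.is_available() else "cpu")
```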
transformers_4_35_0/__init__.py ADDED
The diff for this file is too large to render. See raw diff
 
transformers_4_35_0/activations.py ADDED
@@ -0,0 +1,251 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from collections import OrderedDict
17
+
18
+ import torch
19
+ from packaging import version
20
+ from torch import Tensor, nn
21
+
22
+ from .utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class PytorchGELUTanh(nn.Module):
29
+ """
30
+ A fast C implementation of the tanh approximation of the GeLU activation function. See
31
+ https://arxiv.org/abs/1606.08415.
32
+
33
+ This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
34
+ match due to rounding errors.
35
+ """
36
+
37
+ def __init__(self):
38
+ super().__init__()
39
+ if version.parse(torch.__version__) < version.parse("1.12.0"):
40
+ raise ImportError(
41
+ f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
42
+ "PytorchGELUTanh. Please upgrade torch."
43
+ )
44
+
45
+ def forward(self, input: Tensor) -> Tensor:
46
+ return nn.functional.gelu(input, approximate="tanh")
47
+
48
+
49
+ class NewGELUActivation(nn.Module):
50
+ """
51
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
52
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
53
+ """
54
+
55
+ def forward(self, input: Tensor) -> Tensor:
56
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
57
+
58
+
59
+ class GELUActivation(nn.Module):
60
+ """
61
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
62
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
63
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
64
+ Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
65
+ """
66
+
67
+ def __init__(self, use_gelu_python: bool = False):
68
+ super().__init__()
69
+ if use_gelu_python:
70
+ self.act = self._gelu_python
71
+ else:
72
+ self.act = nn.functional.gelu
73
+
74
+ def _gelu_python(self, input: Tensor) -> Tensor:
75
+ return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
76
+
77
+ def forward(self, input: Tensor) -> Tensor:
78
+ return self.act(input)
79
+
80
+
81
+ class FastGELUActivation(nn.Module):
82
+ """
83
+ Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
84
+ """
85
+
86
+ def forward(self, input: Tensor) -> Tensor:
87
+ return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
88
+
89
+
90
+ class QuickGELUActivation(nn.Module):
91
+ """
92
+ Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
93
+ """
94
+
95
+ def forward(self, input: Tensor) -> Tensor:
96
+ return input * torch.sigmoid(1.702 * input)
97
+
98
+
99
+ class ClippedGELUActivation(nn.Module):
100
+ """
101
+ Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
102
+ it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
103
+ https://arxiv.org/abs/2004.09602.
104
+
105
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
106
+ initially created.
107
+
108
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
109
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
110
+ """
111
+
112
+ def __init__(self, min: float, max: float):
113
+ if min > max:
114
+ raise ValueError(f"min should be < max (got min: {min}, max: {max})")
115
+
116
+ super().__init__()
117
+ self.min = min
118
+ self.max = max
119
+
120
+ def forward(self, x: Tensor) -> Tensor:
121
+ return torch.clip(gelu(x), self.min, self.max)
122
+
123
+
124
+ class AccurateGELUActivation(nn.Module):
125
+ """
126
+ Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
127
+ https://github.com/hendrycks/GELUs
128
+
129
+ Implemented along with MEGA (Moving Average Equipped Gated Attention)
130
+ """
131
+
132
+ def __init__(self):
133
+ super().__init__()
134
+ self.precomputed_constant = math.sqrt(2 / math.pi)
135
+
136
+ def forward(self, input: Tensor) -> Tensor:
137
+ return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
138
+
139
+
140
+ class SiLUActivation(nn.Module):
141
+ """
142
+ See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
143
+ Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
144
+ Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
145
+ Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
146
+ later.
147
+ """
148
+
149
+ def forward(self, input: Tensor) -> Tensor:
150
+ return nn.functional.silu(input)
151
+
152
+
153
+ class MishActivation(nn.Module):
154
+ """
155
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
156
+ visit the official repository for the paper: https://github.com/digantamisra98/Mish
157
+ """
158
+
159
+ def __init__(self):
160
+ super().__init__()
161
+ if version.parse(torch.__version__) < version.parse("1.9.0"):
162
+ self.act = self._mish_python
163
+ else:
164
+ self.act = nn.functional.mish
165
+
166
+ def _mish_python(self, input: Tensor) -> Tensor:
167
+ return input * torch.tanh(nn.functional.softplus(input))
168
+
169
+ def forward(self, input: Tensor) -> Tensor:
170
+ return self.act(input)
171
+
172
+
173
+ class LinearActivation(nn.Module):
174
+ """
175
+ Applies the linear activation function, i.e. forwarding input directly to output.
176
+ """
177
+
178
+ def forward(self, input: Tensor) -> Tensor:
179
+ return input
180
+
181
+
182
+ class LaplaceActivation(nn.Module):
183
+ """
184
+ Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
185
+ https://arxiv.org/abs/2209.10655
186
+
187
+ Inspired by squared relu, but with bounded range and gradient for better stability
188
+ """
189
+
190
+ def forward(self, input, mu=0.707107, sigma=0.282095):
191
+ input = (input - mu).div(sigma * math.sqrt(2.0))
192
+ return 0.5 * (1.0 + torch.erf(input))
193
+
194
+
195
+ class ReLUSquaredActivation(nn.Module):
196
+ """
197
+ Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
198
+ """
199
+
200
+ def forward(self, input):
201
+ relu_applied = nn.functional.relu(input)
202
+ squared = torch.square(relu_applied)
203
+ return squared
204
+
205
+
206
+ class ClassInstantier(OrderedDict):
207
+ def __getitem__(self, key):
208
+ content = super().__getitem__(key)
209
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
210
+ return cls(**kwargs)
211
+
212
+
213
+ ACT2CLS = {
214
+ "gelu": GELUActivation,
215
+ "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
216
+ "gelu_fast": FastGELUActivation,
217
+ "gelu_new": NewGELUActivation,
218
+ "gelu_python": (GELUActivation, {"use_gelu_python": True}),
219
+ "gelu_pytorch_tanh": PytorchGELUTanh,
220
+ "gelu_accurate": AccurateGELUActivation,
221
+ "laplace": LaplaceActivation,
222
+ "linear": LinearActivation,
223
+ "mish": MishActivation,
224
+ "quick_gelu": QuickGELUActivation,
225
+ "relu": nn.ReLU,
226
+ "relu2": ReLUSquaredActivation,
227
+ "relu6": nn.ReLU6,
228
+ "sigmoid": nn.Sigmoid,
229
+ "silu": SiLUActivation,
230
+ "swish": SiLUActivation,
231
+ "tanh": nn.Tanh,
232
+ }
233
+ ACT2FN = ClassInstantier(ACT2CLS)
234
+
235
+
236
+ def get_activation(activation_string):
237
+ if activation_string in ACT2FN:
238
+ return ACT2FN[activation_string]
239
+ else:
240
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
241
+
242
+
243
+ # For backwards compatibility with: from activations import gelu_python
244
+ gelu_python = get_activation("gelu_python")
245
+ gelu_new = get_activation("gelu_new")
246
+ gelu = get_activation("gelu")
247
+ gelu_fast = get_activation("gelu_fast")
248
+ quick_gelu = get_activation("quick_gelu")
249
+ silu = get_activation("silu")
250
+ mish = get_activation("mish")
251
+ linear_act = get_activation("linear")
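
To show how this vendored module is meant to be consumed, here is a minimal sketch of looking up activations by name. It assumes the full vendored package (this view is capped at 50 files) is importable from the repository root:

```python
import torch

# With the stock library this would be `from transformers.activations import ...`
from transformers_4_35_0.activations import ACT2FN, get_activation

x = torch.randn(2, 4)

gelu_new = get_activation("gelu_new")  # returns a NewGELUActivation instance
silu = ACT2FN["silu"]                  # ClassInstantier instantiates SiLUActivation on lookup

print(gelu_new(x).shape, silu(x).shape)  # torch.Size([2, 4]) for both
```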
transformers_4_35_0/activations_tf.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import tensorflow as tf
18
+ from packaging import version
19
+
20
+
21
+ def _gelu(x):
22
+ """
23
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
24
+ initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
25
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
26
+ https://arxiv.org/abs/1606.08415
27
+ """
28
+ x = tf.convert_to_tensor(x)
29
+ cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
30
+
31
+ return x * cdf
32
+
33
+
34
+ def _gelu_new(x):
35
+ """
36
+ Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.0841
37
+
38
+ Args:
39
+ x: float Tensor to perform activation
40
+
41
+ Returns:
42
+ `x` with the GELU activation applied.
43
+ """
44
+ x = tf.convert_to_tensor(x)
45
+ pi = tf.cast(math.pi, x.dtype)
46
+ coeff = tf.cast(0.044715, x.dtype)
47
+ cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
48
+
49
+ return x * cdf
50
+
51
+
52
+ def mish(x):
53
+ x = tf.convert_to_tensor(x)
54
+
55
+ return x * tf.tanh(tf.math.softplus(x))
56
+
57
+
58
+ def gelu_fast(x):
59
+ x = tf.convert_to_tensor(x)
60
+ coeff1 = tf.cast(0.044715, x.dtype)
61
+ coeff2 = tf.cast(0.7978845608, x.dtype)
62
+
63
+ return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))
64
+
65
+
66
+ def quick_gelu(x):
67
+ x = tf.convert_to_tensor(x)
68
+ coeff = tf.cast(1.702, x.dtype)
69
+ return x * tf.math.sigmoid(coeff * x)
70
+
71
+
72
+ def gelu_10(x):
73
+ """
74
+ Clip the range of possible GeLU outputs between [-10, 10]. This is especially useful for quantization purpose, as
75
+ it allows mapping 2 negatives values in the GeLU spectrum. For more information on this trick, please refer to
76
+ https://arxiv.org/abs/2004.09602
77
+
78
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
79
+ initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
80
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
81
+ https://arxiv.org/abs/1606.08415 :param x: :return:
82
+ """
83
+ return tf.clip_by_value(_gelu(x), -10, 10)
84
+
85
+
86
+ def glu(x, axis=-1):
87
+ """
88
+ Gated Linear Unit. Implementation as defined in the original paper (see https://arxiv.org/abs/1612.08083), where
89
+ the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B).
90
+
91
+ Args:
92
+ `x`: float Tensor to perform activation
93
+ `axis`: dimension across which `x` be split in half
94
+
95
+ Returns:
96
+ `x` with the GLU activation applied (with its size halved across the dimension `axis`).
97
+ """
98
+ a, b = tf.split(x, 2, axis=axis)
99
+ return a * tf.math.sigmoid(b)
100
+
101
+
102
+ if version.parse(tf.version.VERSION) >= version.parse("2.4"):
103
+
104
+ def approximate_gelu_wrap(x):
105
+ return tf.keras.activations.gelu(x, approximate=True)
106
+
107
+ gelu = tf.keras.activations.gelu
108
+ gelu_new = approximate_gelu_wrap
109
+ else:
110
+ gelu = _gelu
111
+ gelu_new = _gelu_new
112
+
113
+
114
+ ACT2FN = {
115
+ "gelu": gelu,
116
+ "gelu_10": gelu_10,
117
+ "gelu_fast": gelu_fast,
118
+ "gelu_new": gelu_new,
119
+ "glu": glu,
120
+ "mish": mish,
121
+ "quick_gelu": quick_gelu,
122
+ "relu": tf.keras.activations.relu,
123
+ "sigmoid": tf.keras.activations.sigmoid,
124
+ "silu": tf.keras.activations.swish,
125
+ "swish": tf.keras.activations.swish,
126
+ "tanh": tf.keras.activations.tanh,
127
+ }
128
+
129
+
130
+ def get_tf_activation(activation_string):
131
+ if activation_string in ACT2FN:
132
+ return ACT2FN[activation_string]
133
+ else:
134
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
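
The TensorFlow variant works the same way through `get_tf_activation`; note that on TensorFlow >= 2.4 the `gelu`/`gelu_new` entries defer to `tf.keras.activations.gelu`. A short sketch under the same vendored-package assumption:

```python
import tensorflow as tf

from transformers_4_35_0.activations_tf import get_tf_activation

act = get_tf_activation("gelu_new")   # tanh-approximate GELU (Keras gelu on TF >= 2.4)
y = act(tf.random.normal((2, 4)))
print(y.shape)                        # (2, 4)
```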
transformers_4_35_0/audio_utils.py ADDED
@@ -0,0 +1,721 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
17
+ and remove unnecessary dependencies.
18
+ """
19
+ import warnings
20
+ from typing import Optional, Union
21
+
22
+ import numpy as np
23
+
24
+
25
+ def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
26
+ """
27
+ Convert frequency from hertz to mels.
28
+
29
+ Args:
30
+ freq (`float` or `np.ndarray`):
31
+ The frequency, or multiple frequencies, in hertz (Hz).
32
+ mel_scale (`str`, *optional*, defaults to `"htk"`):
33
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
34
+
35
+ Returns:
36
+ `float` or `np.ndarray`: The frequencies on the mel scale.
37
+ """
38
+
39
+ if mel_scale not in ["slaney", "htk", "kaldi"]:
40
+ raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
41
+
42
+ if mel_scale == "htk":
43
+ return 2595.0 * np.log10(1.0 + (freq / 700.0))
44
+ elif mel_scale == "kaldi":
45
+ return 1127.0 * np.log(1.0 + (freq / 700.0))
46
+
47
+ min_log_hertz = 1000.0
48
+ min_log_mel = 15.0
49
+ logstep = 27.0 / np.log(6.4)
50
+ mels = 3.0 * freq / 200.0
51
+
52
+ if isinstance(freq, np.ndarray):
53
+ log_region = freq >= min_log_hertz
54
+ mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
55
+ elif freq >= min_log_hertz:
56
+ mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
57
+
58
+ return mels
59
+
60
+
61
+ def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
62
+ """
63
+ Convert frequency from mels to hertz.
64
+
65
+ Args:
66
+ mels (`float` or `np.ndarray`):
67
+ The frequency, or multiple frequencies, in mels.
68
+ mel_scale (`str`, *optional*, `"htk"`):
69
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
70
+
71
+ Returns:
72
+ `float` or `np.ndarray`: The frequencies in hertz.
73
+ """
74
+
75
+ if mel_scale not in ["slaney", "htk", "kaldi"]:
76
+ raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
77
+
78
+ if mel_scale == "htk":
79
+ return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
80
+ elif mel_scale == "kaldi":
81
+ return 700.0 * (np.exp(mels / 1127.0) - 1.0)
82
+
83
+ min_log_hertz = 1000.0
84
+ min_log_mel = 15.0
85
+ logstep = np.log(6.4) / 27.0
86
+ freq = 200.0 * mels / 3.0
87
+
88
+ if isinstance(mels, np.ndarray):
89
+ log_region = mels >= min_log_mel
90
+ freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
91
+ elif mels >= min_log_mel:
92
+ freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
93
+
94
+ return freq
95
+
96
+
97
+ def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
98
+ """
99
+ Creates a triangular filter bank.
100
+
101
+ Adapted from *torchaudio* and *librosa*.
102
+
103
+ Args:
104
+ fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
105
+ Discrete frequencies of the FFT bins in Hz.
106
+ filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
107
+ Center frequencies of the triangular filters to create, in Hz.
108
+
109
+ Returns:
110
+ `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
111
+ """
112
+ filter_diff = np.diff(filter_freqs)
113
+ slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
114
+ down_slopes = -slopes[:, :-2] / filter_diff[:-1]
115
+ up_slopes = slopes[:, 2:] / filter_diff[1:]
116
+ return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
117
+
118
+
119
+ def mel_filter_bank(
120
+ num_frequency_bins: int,
121
+ num_mel_filters: int,
122
+ min_frequency: float,
123
+ max_frequency: float,
124
+ sampling_rate: int,
125
+ norm: Optional[str] = None,
126
+ mel_scale: str = "htk",
127
+ triangularize_in_mel_space: bool = False,
128
+ ) -> np.ndarray:
129
+ """
130
+ Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
131
+ various implementation exist, which differ in the number of filters, the shape of the filters, the way the filters
132
+ are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
133
+ features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.
134
+
135
+ Different banks of mel filters were introduced in the literature. The following variations are supported:
136
+
137
+ - MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech
138
+ bandwidth of `[0, 4600]` Hz.
139
+ - MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech
140
+ bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz.
141
+ - MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and
142
+ speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
143
+ - HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of
144
+ 12.5 kHz and speech bandwidth of `[0, 6250]` Hz.
145
+
146
+ This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
147
+ `melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.
148
+
149
+ Args:
150
+ num_frequency_bins (`int`):
151
+ Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
152
+ num_mel_filters (`int`):
153
+ Number of mel filters to generate.
154
+ min_frequency (`float`):
155
+ Lowest frequency of interest in Hz.
156
+ max_frequency (`float`):
157
+ Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
158
+ sampling_rate (`int`):
159
+ Sample rate of the audio waveform.
160
+ norm (`str`, *optional*):
161
+ If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
162
+ mel_scale (`str`, *optional*, defaults to `"htk"`):
163
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
164
+ triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
165
+ If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
166
+ should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.
167
+
168
+ Returns:
169
+ `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
170
+ projection matrix to go from a spectrogram to a mel spectrogram.
171
+ """
172
+ if norm is not None and norm != "slaney":
173
+ raise ValueError('norm must be one of None or "slaney"')
174
+
175
+ # center points of the triangular mel filters
176
+ mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
177
+ mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
178
+ mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
179
+ filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
180
+
181
+ if triangularize_in_mel_space:
182
+ # frequencies of FFT bins in Hz, but filters triangularized in mel space
183
+ fft_bin_width = sampling_rate / (num_frequency_bins * 2)
184
+ fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
185
+ filter_freqs = mel_freqs
186
+ else:
187
+ # frequencies of FFT bins in Hz
188
+ fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
189
+
190
+ mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
191
+
192
+ if norm is not None and norm == "slaney":
193
+ # Slaney-style mel is scaled to be approx constant energy per channel
194
+ enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
195
+ mel_filters *= np.expand_dims(enorm, 0)
196
+
197
+ if (mel_filters.max(axis=0) == 0.0).any():
198
+ warnings.warn(
199
+ "At least one mel filter has all zero values. "
200
+ f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
201
+ f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
202
+ )
203
+
204
+ return mel_filters
205
+
206
+
207
+ def optimal_fft_length(window_length: int) -> int:
208
+ """
209
+ Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
210
+ already a power of two, rounds it up to the next power or two.
211
+
212
+ The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
213
+ of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
214
+ is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
215
+ it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
216
+ """
217
+ return 2 ** int(np.ceil(np.log2(window_length)))
218
+
219
+
220
+ def window_function(
221
+ window_length: int,
222
+ name: str = "hann",
223
+ periodic: bool = True,
224
+ frame_length: Optional[int] = None,
225
+ center: bool = True,
226
+ ) -> np.ndarray:
227
+ """
228
+ Returns an array containing the specified window. This window is intended to be used with `stft`.
229
+
230
+ The following window types are supported:
231
+
232
+ - `"boxcar"`: a rectangular window
233
+ - `"hamming"`: the Hamming window
234
+ - `"hann"`: the Hann window
235
+ - `"povey"`: the Povey window
236
+
237
+ Args:
238
+ window_length (`int`):
239
+ The length of the window in samples.
240
+ name (`str`, *optional*, defaults to `"hann"`):
241
+ The name of the window function.
242
+ periodic (`bool`, *optional*, defaults to `True`):
243
+ Whether the window is periodic or symmetric.
244
+ frame_length (`int`, *optional*):
245
+ The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
246
+ than the frame length, so that it will be zero-padded.
247
+ center (`bool`, *optional*, defaults to `True`):
248
+ Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.
249
+
250
+ Returns:
251
+ `np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
252
+ """
253
+ length = window_length + 1 if periodic else window_length
254
+
255
+ if name == "boxcar":
256
+ window = np.ones(length)
257
+ elif name in ["hamming", "hamming_window"]:
258
+ window = np.hamming(length)
259
+ elif name in ["hann", "hann_window"]:
260
+ window = np.hanning(length)
261
+ elif name in ["povey"]:
262
+ window = np.power(np.hanning(length), 0.85)
263
+ else:
264
+ raise ValueError(f"Unknown window function '{name}'")
265
+
266
+ if periodic:
267
+ window = window[:-1]
268
+
269
+ if frame_length is None:
270
+ return window
271
+
272
+ if window_length > frame_length:
273
+ raise ValueError(
274
+ f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
275
+ )
276
+
277
+ padded_window = np.zeros(frame_length)
278
+ offset = (frame_length - window_length) // 2 if center else 0
279
+ padded_window[offset : offset + window_length] = window
280
+ return padded_window
281
+
282
+
283
+ # TODO This method does not support batching yet as we are mainly focused on inference.
284
+ def spectrogram(
285
+ waveform: np.ndarray,
286
+ window: np.ndarray,
287
+ frame_length: int,
288
+ hop_length: int,
289
+ fft_length: Optional[int] = None,
290
+ power: Optional[float] = 1.0,
291
+ center: bool = True,
292
+ pad_mode: str = "reflect",
293
+ onesided: bool = True,
294
+ preemphasis: Optional[float] = None,
295
+ mel_filters: Optional[np.ndarray] = None,
296
+ mel_floor: float = 1e-10,
297
+ log_mel: Optional[str] = None,
298
+ reference: float = 1.0,
299
+ min_value: float = 1e-10,
300
+ db_range: Optional[float] = None,
301
+ remove_dc_offset: Optional[bool] = None,
302
+ dtype: np.dtype = np.float32,
303
+ ) -> np.ndarray:
304
+ """
305
+ Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
306
+
307
+ This function can create the following kinds of spectrograms:
308
+
309
+ - amplitude spectrogram (`power = 1.0`)
310
+ - power spectrogram (`power = 2.0`)
311
+ - complex-valued spectrogram (`power = None`)
312
+ - log spectrogram (use `log_mel` argument)
313
+ - mel spectrogram (provide `mel_filters`)
314
+ - log-mel spectrogram (provide `mel_filters` and `log_mel`)
315
+
316
+ How this works:
317
+
318
+ 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
319
+ - hop_length` samples.
320
+ 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
321
+ 3. The DFT is taken of each windowed frame.
322
+ 4. The results are stacked into a spectrogram.
323
+
324
+ We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
325
+
326
+ - The analysis frame. This is the size of the time slices that the input waveform is split into.
327
+ - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
328
+ - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
329
+
330
+ In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
331
+ padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
332
+ typically the next power of two.
333
+
334
+ Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
335
+ `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
336
+ can be constructed.
337
+
338
+ Args:
339
+ waveform (`np.ndarray` of shape `(length,)`):
340
+ The input waveform. This must be a single real-valued, mono waveform.
341
+ window (`np.ndarray` of shape `(frame_length,)`):
342
+ The windowing function to apply, including zero-padding if necessary. The actual window length may be
343
+ shorter than `frame_length`, but we're assuming the array has already been zero-padded.
344
+ frame_length (`int`):
345
+ The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also
346
+ allow smaller sizes.
347
+ hop_length (`int`):
348
+ The stride between successive analysis frames in samples.
349
+ fft_length (`int`, *optional*):
350
+ The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
351
+ For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
352
+ power (`float`, *optional*, defaults to 1.0):
353
+ If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
354
+ complex numbers.
355
+ center (`bool`, *optional*, defaults to `True`):
356
+ Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
357
+ `t` will start at time `t * hop_length`.
358
+ pad_mode (`str`, *optional*, defaults to `"reflect"`):
359
+ Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
360
+ (pad with edge values), `"reflect"` (pads with mirrored values).
361
+ onesided (`bool`, *optional*, defaults to `True`):
362
+ If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
363
+ frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
364
+ preemphasis (`float`, *optional*)
365
+ Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
366
+ mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
367
+ The mel filter bank. If supplied, applies a this filter bank to create a mel spectrogram.
368
+ mel_floor (`float`, *optional*, defaults to 1e-10):
369
+ Minimum value of mel frequency banks.
370
+ log_mel (`str`, *optional*):
371
+ How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
372
+ the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
373
+ used when `power` is not `None`.
374
+ reference (`float`, *optional*, defaults to 1.0):
375
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
376
+ the loudest part to 0 dB. Must be greater than zero.
377
+ min_value (`float`, *optional*, defaults to `1e-10`):
378
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
379
+ `log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
380
+ amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
381
+ db_range (`float`, *optional*):
382
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
383
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
384
+ remove_dc_offset (`bool`, *optional*):
385
+ Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
386
+ order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
387
+ dtype (`np.dtype`, *optional*, defaults to `np.float32`):
388
+ Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
389
+ `np.complex64`.
390
+
391
+ Returns:
392
+ `nd.array` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
393
+ `(num_mel_filters, length)` for a mel spectrogram.
394
+ """
395
+ window_length = len(window)
396
+
397
+ if fft_length is None:
398
+ fft_length = frame_length
399
+
400
+ if frame_length > fft_length:
401
+ raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
402
+
403
+ if window_length != frame_length:
404
+ raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
405
+
406
+ if hop_length <= 0:
407
+ raise ValueError("hop_length must be greater than zero")
408
+
409
+ if waveform.ndim != 1:
410
+ raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
411
+
412
+ if np.iscomplexobj(waveform):
413
+ raise ValueError("Complex-valued input waveforms are not currently supported")
414
+
415
+ # center pad the waveform
416
+ if center:
417
+ padding = [(int(frame_length // 2), int(frame_length // 2))]
418
+ waveform = np.pad(waveform, padding, mode=pad_mode)
419
+
420
+ # promote to float64, since np.fft uses float64 internally
421
+ waveform = waveform.astype(np.float64)
422
+ window = window.astype(np.float64)
423
+
424
+ # split waveform into frames of frame_length size
425
+ num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))
426
+
427
+ num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
428
+ spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)
429
+
430
+ # rfft is faster than fft
431
+ fft_func = np.fft.rfft if onesided else np.fft.fft
432
+ buffer = np.zeros(fft_length)
433
+
434
+ timestep = 0
435
+ for frame_idx in range(num_frames):
436
+ buffer[:frame_length] = waveform[timestep : timestep + frame_length]
437
+
438
+ if remove_dc_offset:
439
+ buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
440
+
441
+ if preemphasis is not None:
442
+ buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
443
+ buffer[0] *= 1 - preemphasis
444
+
445
+ buffer[:frame_length] *= window
446
+
447
+ spectrogram[frame_idx] = fft_func(buffer)
448
+ timestep += hop_length
449
+
450
+ # note: ** is much faster than np.power
451
+ if power is not None:
452
+ spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
453
+
454
+ spectrogram = spectrogram.T
455
+
456
+ if mel_filters is not None:
457
+ spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
458
+
459
+ if power is not None and log_mel is not None:
460
+ if log_mel == "log":
461
+ spectrogram = np.log(spectrogram)
462
+ elif log_mel == "log10":
463
+ spectrogram = np.log10(spectrogram)
464
+ elif log_mel == "dB":
465
+ if power == 1.0:
466
+ spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
467
+ elif power == 2.0:
468
+ spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
469
+ else:
470
+ raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
471
+ else:
472
+ raise ValueError(f"Unknown log_mel option: {log_mel}")
473
+
474
+ spectrogram = np.asarray(spectrogram, dtype)
475
+
476
+ return spectrogram
477
+
478
+
479
+ def power_to_db(
480
+ spectrogram: np.ndarray,
481
+ reference: float = 1.0,
482
+ min_value: float = 1e-10,
483
+ db_range: Optional[float] = None,
484
+ ) -> np.ndarray:
485
+ """
486
+ Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic
487
+ logarithm properties for numerical stability.
488
+
489
+ The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
490
+ linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
491
+ This means that large variations in energy may not sound all that different if the sound is loud to begin with.
492
+ This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
493
+
494
+ Based on the implementation of `librosa.power_to_db`.
495
+
496
+ Args:
497
+ spectrogram (`np.ndarray`):
498
+ The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
499
+ reference (`float`, *optional*, defaults to 1.0):
500
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
501
+ the loudest part to 0 dB. Must be greater than zero.
502
+ min_value (`float`, *optional*, defaults to `1e-10`):
503
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
504
+ `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
505
+ db_range (`float`, *optional*):
506
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
507
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
508
+
509
+ Returns:
510
+ `np.ndarray`: the spectrogram in decibels
511
+ """
512
+ if reference <= 0.0:
513
+ raise ValueError("reference must be greater than zero")
514
+ if min_value <= 0.0:
515
+ raise ValueError("min_value must be greater than zero")
516
+
517
+ reference = max(min_value, reference)
518
+
519
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
520
+ spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
521
+
522
+ if db_range is not None:
523
+ if db_range <= 0.0:
524
+ raise ValueError("db_range must be greater than zero")
525
+ spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
526
+
527
+ return spectrogram
528
+
529
+
530
+ def amplitude_to_db(
531
+ spectrogram: np.ndarray,
532
+ reference: float = 1.0,
533
+ min_value: float = 1e-5,
534
+ db_range: Optional[float] = None,
535
+ ) -> np.ndarray:
536
+ """
537
+ Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using
538
+ basic logarithm properties for numerical stability.
539
+
540
+ The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
541
+ linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
542
+ This means that large variations in energy may not sound all that different if the sound is loud to begin with.
543
+ This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
544
+
545
+ Args:
546
+ spectrogram (`np.ndarray`):
547
+ The input amplitude (mel) spectrogram.
548
+ reference (`float`, *optional*, defaults to 1.0):
549
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
550
+ the loudest part to 0 dB. Must be greater than zero.
551
+ min_value (`float`, *optional*, defaults to `1e-5`):
552
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
553
+ `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
554
+ db_range (`float`, *optional*):
555
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
556
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
557
+
558
+ Returns:
559
+ `np.ndarray`: the spectrogram in decibels
560
+ """
561
+ if reference <= 0.0:
562
+ raise ValueError("reference must be greater than zero")
563
+ if min_value <= 0.0:
564
+ raise ValueError("min_value must be greater than zero")
565
+
566
+ reference = max(min_value, reference)
567
+
568
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
569
+ spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
570
+
571
+ if db_range is not None:
572
+ if db_range <= 0.0:
573
+ raise ValueError("db_range must be greater than zero")
574
+ spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
575
+
576
+ return spectrogram
577
+
578
+
579
+ ### deprecated functions below this line ###
580
+
581
+
582
+ def get_mel_filter_banks(
583
+ nb_frequency_bins: int,
584
+ nb_mel_filters: int,
585
+ frequency_min: float,
586
+ frequency_max: float,
587
+ sample_rate: int,
588
+ norm: Optional[str] = None,
589
+ mel_scale: str = "htk",
590
+ ) -> np.array:
591
+ warnings.warn(
592
+ "The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers",
593
+ FutureWarning,
594
+ )
595
+ return mel_filter_bank(
596
+ num_frequency_bins=nb_frequency_bins,
597
+ num_mel_filters=nb_mel_filters,
598
+ min_frequency=frequency_min,
599
+ max_frequency=frequency_max,
600
+ sampling_rate=sample_rate,
601
+ norm=norm,
602
+ mel_scale=mel_scale,
603
+ )
604
+
605
+
606
+ def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True):
607
+ """
608
+ In order to compute the short time fourier transform, the waveform needs to be split in overlapping windowed
609
+ segments called `frames`.
610
+
611
+ The window length (window_length) defines how much of the signal is contained in each frame, while the hop length
612
+ defines the step between the beginning of each new frame.
613
+
614
+
615
+ Args:
616
+ waveform (`np.array` of shape `(sample_length,)`):
617
+ The raw waveform which will be split into smaller chunks.
618
+ hop_length (`int`, *optional*, defaults to 160):
619
+ Step between each window of the waveform.
620
+ fft_window_size (`int`, *optional*, defaults to 400):
621
+ Defines the size of the window.
622
+ center (`bool`, defaults to `True`):
623
+ Whether or not to center each frame around the middle of the frame. Centering is done by reflecting the
624
+ waveform on the left and on the right.
625
+
626
+ Return:
627
+ framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`):
628
+ The framed waveforms that can be fed to `np.fft`.
629
+ """
630
+ warnings.warn(
631
+ "The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers",
632
+ FutureWarning,
633
+ )
634
+ frames = []
635
+ for i in range(0, waveform.shape[0] + 1, hop_length):
636
+ if center:
637
+ half_window = (fft_window_size - 1) // 2 + 1
638
+ start = i - half_window if i > half_window else 0
639
+ end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
640
+ frame = waveform[start:end]
641
+ if start == 0:
642
+ padd_width = (-i + half_window, 0)
643
+ frame = np.pad(frame, pad_width=padd_width, mode="reflect")
644
+
645
+ elif end == waveform.shape[0]:
646
+ padd_width = (0, (i - waveform.shape[0] + half_window))
647
+ frame = np.pad(frame, pad_width=padd_width, mode="reflect")
648
+
649
+ else:
650
+ frame = waveform[i : i + fft_window_size]
651
+ frame_width = frame.shape[0]
652
+ if frame_width < waveform.shape[0]:
653
+ frame = np.lib.pad(
654
+ frame, pad_width=(0, fft_window_size - frame_width), mode="constant", constant_values=0
655
+ )
656
+ frames.append(frame)
657
+
658
+ frames = np.stack(frames, 0)
659
+ return frames
660
+
661
+
662
+ def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None):
663
+ """
664
+ Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results
665
+ as `torch.stft`.
666
+
667
+ Args:
668
+ frames (`np.array` of dimension `(num_frames, fft_window_size)`):
669
+ A framed audio signal obtained using `audio_utils.fram_wav`.
670
+ windowing_function (`np.array` of dimension `(nb_frequency_bins, nb_mel_filters)`:
671
+ A array reprensenting the function that will be used to reduces the amplitude of the discontinuities at the
672
+ boundaries of each frame when computing the STFT. Each frame will be multiplied by the windowing_function.
673
+ For more information on the discontinuities, called *Spectral leakage*, refer to [this
674
+ tutorial]https://download.ni.com/evaluation/pxi/Understanding%20FFTs%20and%20Windowing.pdf
675
+ fft_window_size (`int`, *optional*):
676
+ Size of the window om which the Fourier transform is applied. This controls the frequency resolution of the
677
+ spectrogram. 400 means that the fourrier transform is computed on windows of 400 samples. The number of
678
+ frequency bins (`nb_frequency_bins`) used to divide the window into equal strips is equal to
679
+ `(1+fft_window_size)//2`. An increase of the fft_window_size slows the calculus time proportionnally.
680
+
681
+ Example:
682
+
683
+ ```python
684
+ >>> from transformers.audio_utils import stft, fram_wave
685
+ >>> import numpy as np
686
+
687
+ >>> audio = np.random.rand(50)
688
+ >>> fft_window_size = 10
689
+ >>> hop_length = 2
690
+ >>> framed_audio = fram_wave(audio, hop_length, fft_window_size)
691
+ >>> spectrogram = stft(framed_audio, np.hanning(fft_window_size + 1))
692
+ ```
693
+
694
+ Returns:
695
+ spectrogram (`np.ndarray`):
696
+ A spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using the STFT algorithm
697
+ """
698
+ warnings.warn(
699
+ "The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers",
700
+ FutureWarning,
701
+ )
702
+ frame_size = frames.shape[1]
703
+
704
+ if fft_window_size is None:
705
+ fft_window_size = frame_size
706
+
707
+ if fft_window_size < frame_size:
708
+ raise ValueError("FFT size must greater or equal the frame size")
709
+ # number of FFT bins to store
710
+ nb_frequency_bins = (fft_window_size >> 1) + 1
711
+
712
+ spectrogram = np.empty((len(frames), nb_frequency_bins), dtype=np.complex64)
713
+ fft_signal = np.zeros(fft_window_size)
714
+
715
+ for f, frame in enumerate(frames):
716
+ if windowing_function is not None:
717
+ np.multiply(frame, windowing_function, out=fft_signal[:frame_size])
718
+ else:
719
+ fft_signal[:frame_size] = frame
720
+ spectrogram[f] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins]
721
+ return spectrogram.T
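
Putting the three building blocks together (window, mel filter bank, STFT-based spectrogram), here is a minimal log-mel sketch; the frame/hop/filter settings are typical 16 kHz speech values chosen for illustration, not values taken from this file:

```python
import numpy as np

from transformers_4_35_0.audio_utils import mel_filter_bank, spectrogram, window_function

sr = 16000
waveform = np.random.randn(sr).astype(np.float32)  # 1 s of noise standing in for real audio

n_fft, hop = 400, 160
window = window_function(n_fft, "hann")
mel_filters = mel_filter_bank(
    num_frequency_bins=1 + n_fft // 2,   # must match the onesided bin count of the spectrogram
    num_mel_filters=80,
    min_frequency=0.0,
    max_frequency=sr / 2,
    sampling_rate=sr,
    norm="slaney",
    mel_scale="slaney",
)

log_mel = spectrogram(
    waveform,
    window,
    frame_length=n_fft,
    hop_length=hop,
    power=2.0,            # power spectrogram...
    mel_filters=mel_filters,
    log_mel="log10",      # ...then log10 of the mel energies
)
print(log_mel.shape)  # (80, num_frames)
```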
transformers_4_35_0/benchmark/__init__.py ADDED
File without changes
transformers_4_35_0/benchmark/benchmark.py ADDED
@@ -0,0 +1,271 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Benchmarking the library on inference and training in PyTorch.
18
+ """
19
+
20
+
21
+ import timeit
22
+ from typing import Callable, Optional
23
+
24
+ from ..configuration_utils import PretrainedConfig
25
+ from ..models.auto.modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING
26
+ from ..utils import is_py3nvml_available, is_torch_available, logging
27
+ from .benchmark_utils import (
28
+ Benchmark,
29
+ Memory,
30
+ MemorySummary,
31
+ measure_peak_memory_cpu,
32
+ start_memory_tracing,
33
+ stop_memory_tracing,
34
+ )
35
+
36
+
37
+ if is_torch_available():
38
+ import torch
39
+
40
+ from .benchmark_args import PyTorchBenchmarkArguments
41
+
42
+
43
+ if is_py3nvml_available():
44
+ import py3nvml.py3nvml as nvml
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ class PyTorchBenchmark(Benchmark):
51
+ args: PyTorchBenchmarkArguments
52
+ configs: PretrainedConfig
53
+ framework: str = "PyTorch"
54
+
55
+ @property
56
+ def framework_version(self):
57
+ return torch.__version__
58
+
59
+ def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
60
+ _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
61
+ return self._measure_speed(_inference)
62
+
63
+ def _inference_memory(
64
+ self, model_name: str, batch_size: int, sequence_length: int
65
+ ) -> [Memory, Optional[MemorySummary]]:
66
+ _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
67
+ return self._measure_memory(_inference)
68
+
69
+ def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
70
+ _train = self._prepare_train_func(model_name, batch_size, sequence_length)
71
+ return self._measure_speed(_train)
72
+
73
+ def _train_memory(
74
+ self, model_name: str, batch_size: int, sequence_length: int
75
+ ) -> [Memory, Optional[MemorySummary]]:
76
+ _train = self._prepare_train_func(model_name, batch_size, sequence_length)
77
+ return self._measure_memory(_train)
78
+
79
+ def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
80
+ config = self.config_dict[model_name]
81
+
82
+ if self.args.torchscript:
83
+ config.torchscript = True
84
+
85
+ has_model_class_in_config = (
86
+ hasattr(config, "architectures")
87
+ and isinstance(config.architectures, list)
88
+ and len(config.architectures) > 0
89
+ )
90
+ if not self.args.only_pretrain_model and has_model_class_in_config:
91
+ try:
92
+ model_class = config.architectures[0]
93
+ transformers_module = __import__("transformers", fromlist=[model_class])
94
+ model_cls = getattr(transformers_module, model_class)
95
+ model = model_cls(config)
96
+ except ImportError:
97
+ raise ImportError(
98
+ f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
99
+ " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
100
+ )
101
+ else:
102
+ model = MODEL_MAPPING[config.__class__](config)
103
+
104
+ model.eval()
105
+ model.to(self.args.device)
106
+
107
+ # encoder-decoder has vocab size saved differently
108
+ vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
109
+ input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device)
110
+
111
+ if self.args.fp16:
112
+ logger.info("Running training in Mixed Precision...")
113
+ if not self.args.is_gpu:
114
+ raise ValueError("Mixed precision is possible only for GPU.")
115
+ # amp seems to have memory leaks so that memory usage
116
+ # is measured using .half() for now https://github.com/NVIDIA/apex/issues/439
117
+ model.half()
118
+
119
+ if self.args.torchscript:
120
+ with torch.no_grad():
121
+ inference_model = torch.jit.trace(model, input_ids)
122
+ else:
123
+ inference_model = model
124
+
125
+ def encoder_decoder_forward():
126
+ with torch.no_grad():
127
+ outputs = inference_model(input_ids, decoder_input_ids=input_ids)
128
+ return outputs
129
+
130
+ def encoder_forward():
131
+ with torch.no_grad():
132
+ outputs = inference_model(input_ids)
133
+ return outputs
134
+
135
+ _forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
136
+ return _forward
137
+
138
+ def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
139
+ config = self.config_dict[model_name]
140
+
141
+ has_model_class_in_config = (
142
+ hasattr(config, "architectures")
143
+ and isinstance(config.architectures, list)
144
+ and len(config.architectures) > 0
145
+ )
146
+ if not self.args.only_pretrain_model and has_model_class_in_config:
147
+ try:
148
+ model_class = config.architectures[0]
149
+ transformers_module = __import__("transformers", fromlist=[model_class])
150
+ model_cls = getattr(transformers_module, model_class)
151
+ model = model_cls(config)
152
+ except ImportError:
153
+ raise ImportError(
154
+ f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
155
+ " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
156
+ )
157
+ else:
158
+ model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
159
+
160
+ if self.args.torchscript:
161
+ raise NotImplementedError("Training for torchscript is currently not implemented")
162
+ else:
163
+ train_model = model
164
+
165
+ model.train()
166
+ model.to(self.args.device)
167
+
168
+ # encoder-decoder has vocab size saved differently
169
+ vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
170
+ input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device)
171
+
172
+ if self.args.fp16:
173
+ logger.info("Running training in Mixed Precision...")
174
+ if not self.args.is_gpu:
175
+ raise ValueError("Mixed precision is possible only for GPU.")
176
+
177
+ # amp seems to have memory leaks so that memory usage
178
+ # is measured using .half() for now https://github.com/NVIDIA/apex/issues/439
179
+ model.half()
180
+
181
+ def compute_loss_and_backprob_encoder():
182
+ loss = train_model(input_ids, labels=input_ids)[0]
183
+ loss.backward()
184
+ return loss
185
+
186
+ def compute_loss_and_backprob_encoder_decoder():
187
+ loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
188
+ loss.backward()
189
+ return loss
190
+
191
+ _train = (
192
+ compute_loss_and_backprob_encoder_decoder
193
+ if config.is_encoder_decoder
194
+ else compute_loss_and_backprob_encoder
195
+ )
196
+ return _train
197
+
198
+ def _measure_speed(self, func) -> float:
199
+ try:
200
+ if self.args.is_tpu or self.args.torchscript:
201
+ # run additional 5 times to stabilize compilation for tpu and torchscript
202
+ logger.info("Do inference on TPU or torchscript. Running model 5 times to stabilize compilation")
203
+ timeit.repeat(
204
+ func,
205
+ repeat=1,
206
+ number=5,
207
+ )
208
+
209
+ # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
210
+ runtimes = timeit.repeat(
211
+ func,
212
+ repeat=self.args.repeat,
213
+ number=10,
214
+ )
215
+
216
+ if self.args.is_tpu and self.args.torch_xla_tpu_print_metrics:
217
+ import torch_xla.debug.metrics as met
218
+
219
+ self.print_fn(met.metrics_report())
220
+
221
+ return min(runtimes) / 10.0
222
+ except RuntimeError as e:
223
+ self.print_fn(f"Doesn't fit on GPU. {e}")
224
+ return "N/A"
225
+
226
+ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
227
+ try:
228
+ if self.args.trace_memory_line_by_line:
229
+ trace = start_memory_tracing("transformers")
230
+
231
+ if self.args.is_tpu:
232
+ # tpu
233
+ raise NotImplementedError(
234
+ "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with"
235
+ " `--no-memory` or `args.memory=False`"
236
+ )
237
+ elif self.args.is_gpu:
238
+ if not is_py3nvml_available():
239
+ logger.warning(
240
+ "py3nvml not installed, we won't log GPU memory usage. "
241
+ "Install py3nvml (pip install py3nvml) to log information about GPU."
242
+ )
243
+ memory = "N/A"
244
+ else:
245
+ logger.info(
246
+ "Measuring total GPU usage on GPU device. Make sure to not have additional processes running"
247
+ " on the same GPU."
248
+ )
249
+ # init nvml
250
+ nvml.nvmlInit()
251
+ func()
252
+ handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
253
+ meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
254
+ max_bytes_in_use = meminfo.used
255
+ memory = Memory(max_bytes_in_use)
256
+ # shutdown nvml
257
+ nvml.nvmlShutdown()
258
+ else:
259
+ # cpu
260
+ memory_bytes = measure_peak_memory_cpu(func)
261
+ memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
262
+
263
+ if self.args.trace_memory_line_by_line:
264
+ summary = stop_memory_tracing(trace)
265
+ else:
266
+ summary = None
267
+
268
+ return memory, summary
269
+ except RuntimeError as e:
270
+ self.print_fn(f"Doesn't fit on GPU. {e}")
271
+ return "N/A", None
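For orientation, here is a minimal usage sketch of the `PyTorchBenchmark` class added above (illustrative checkpoint and sizes; it assumes the classes are importable from the installed `transformers` package and relies on the `Benchmark.run()` entry point defined further down in benchmark_utils.py):

from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

# Illustrative settings: measure inference speed and memory for one small checkpoint.
args = PyTorchBenchmarkArguments(
    models=["bert-base-uncased"],   # any Hub checkpoint name
    batch_sizes=[8],
    sequence_lengths=[8, 32, 128],
)
benchmark = PyTorchBenchmark(args)
results = benchmark.run()           # BenchmarkOutput namedtuple (see benchmark_utils.py)
print(results.time_inference_result)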
transformers_4_35_0/benchmark/benchmark_args.py ADDED
@@ -0,0 +1,114 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from dataclasses import dataclass, field
18
+ from typing import Tuple
19
+
20
+ from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
21
+ from .benchmark_args_utils import BenchmarkArguments
22
+
23
+
24
+ if is_torch_available():
25
+ import torch
26
+
27
+ if is_torch_tpu_available(check_device=False):
28
+ import torch_xla.core.xla_model as xm
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ @dataclass
35
+ class PyTorchBenchmarkArguments(BenchmarkArguments):
36
+ deprecated_args = [
37
+ "no_inference",
38
+ "no_cuda",
39
+ "no_tpu",
40
+ "no_speed",
41
+ "no_memory",
42
+ "no_env_print",
43
+ "no_multi_process",
44
+ ]
45
+
46
+ def __init__(self, **kwargs):
47
+ """
48
+ This __init__ is there for legacy code. When removing deprecated args completely, the class can simply be
49
+ deleted
50
+ """
51
+ for deprecated_arg in self.deprecated_args:
52
+ if deprecated_arg in kwargs:
53
+ positive_arg = deprecated_arg[3:]
54
+ setattr(self, positive_arg, not kwargs.pop(deprecated_arg))
55
+ logger.warning(
56
+ f"{deprecated_arg} is deprecated. Please use --no_{positive_arg} or"
57
+ f" {positive_arg}={getattr(self, positive_arg)}"
58
+ )
59
+
60
+ self.torchscript = kwargs.pop("torchscript", self.torchscript)
61
+ self.torch_xla_tpu_print_metrics = kwargs.pop("torch_xla_tpu_print_metrics", self.torch_xla_tpu_print_metrics)
62
+ self.fp16_opt_level = kwargs.pop("fp16_opt_level", self.fp16_opt_level)
63
+ super().__init__(**kwargs)
64
+
65
+ torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"})
66
+ torch_xla_tpu_print_metrics: bool = field(default=False, metadata={"help": "Print Xla/PyTorch tpu metrics"})
67
+ fp16_opt_level: str = field(
68
+ default="O1",
69
+ metadata={
70
+ "help": (
71
+ "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
72
+ "See details at https://nvidia.github.io/apex/amp.html"
73
+ )
74
+ },
75
+ )
76
+
77
+ @cached_property
78
+ def _setup_devices(self) -> Tuple["torch.device", int]:
79
+ requires_backends(self, ["torch"])
80
+ logger.info("PyTorch: setting up devices")
81
+ if not self.cuda:
82
+ device = torch.device("cpu")
83
+ n_gpu = 0
84
+ elif is_torch_tpu_available():
85
+ device = xm.xla_device()
86
+ n_gpu = 0
87
+ else:
88
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
89
+ n_gpu = torch.cuda.device_count()
90
+ return device, n_gpu
91
+
92
+ @property
93
+ def is_tpu(self):
94
+ return is_torch_tpu_available() and self.tpu
95
+
96
+ @property
97
+ def device_idx(self) -> int:
98
+ requires_backends(self, ["torch"])
99
+ # TODO(PVP): currently only single GPU is supported
100
+ return torch.cuda.current_device()
101
+
102
+ @property
103
+ def device(self) -> "torch.device":
104
+ requires_backends(self, ["torch"])
105
+ return self._setup_devices[0]
106
+
107
+ @property
108
+ def n_gpu(self):
109
+ requires_backends(self, ["torch"])
110
+ return self._setup_devices[1]
111
+
112
+ @property
113
+ def is_gpu(self):
114
+ return self.n_gpu > 0
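A short sketch of how the device properties defined above resolve when CUDA is disabled (illustrative values; requires a local PyTorch install):

from transformers import PyTorchBenchmarkArguments

# Force CPU: with cuda=False the cached _setup_devices returns a CPU device and 0 GPUs.
args = PyTorchBenchmarkArguments(models=["bert-base-uncased"], cuda=False)
print(args.device)   # cpu
print(args.n_gpu)    # 0
print(args.is_gpu)   # False
print(args.is_tpu)   # False unless torch_xla and a TPU are available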
transformers_4_35_0/benchmark/benchmark_args_tf.py ADDED
@@ -0,0 +1,136 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from dataclasses import dataclass, field
18
+ from typing import Tuple
19
+
20
+ from ..utils import cached_property, is_tf_available, logging, requires_backends
21
+ from .benchmark_args_utils import BenchmarkArguments
22
+
23
+
24
+ if is_tf_available():
25
+ import tensorflow as tf
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ @dataclass
32
+ class TensorFlowBenchmarkArguments(BenchmarkArguments):
33
+ deprecated_args = [
34
+ "no_inference",
35
+ "no_cuda",
36
+ "no_tpu",
37
+ "no_speed",
38
+ "no_memory",
39
+ "no_env_print",
40
+ "no_multi_process",
41
+ ]
42
+
43
+ def __init__(self, **kwargs):
44
+ """
45
+ This __init__ is there for legacy code. When removing deprecated args completely, the class can simply be
46
+ deleted
47
+ """
48
+ for deprecated_arg in self.deprecated_args:
49
+ if deprecated_arg in kwargs:
50
+ positive_arg = deprecated_arg[3:]
51
+ kwargs[positive_arg] = not kwargs.pop(deprecated_arg)
52
+ logger.warning(
53
+ f"{deprecated_arg} is deprecated. Please use --no-{positive_arg} or"
54
+ f" {positive_arg}={kwargs[positive_arg]}"
55
+ )
56
+ self.tpu_name = kwargs.pop("tpu_name", self.tpu_name)
57
+ self.device_idx = kwargs.pop("device_idx", self.device_idx)
58
+ self.eager_mode = kwargs.pop("eager_mode", self.eager_mode)
59
+ self.use_xla = kwargs.pop("use_xla", self.use_xla)
60
+ super().__init__(**kwargs)
61
+
62
+ tpu_name: str = field(
63
+ default=None,
64
+ metadata={"help": "Name of TPU"},
65
+ )
66
+ device_idx: int = field(
67
+ default=0,
68
+ metadata={"help": "CPU / GPU device index. Defaults to 0."},
69
+ )
70
+ eager_mode: bool = field(default=False, metadata={"help": "Benchmark models in eager mode."})
71
+ use_xla: bool = field(
72
+ default=False,
73
+ metadata={
74
+ "help": "Benchmark models using XLA JIT compilation. Note that `eager_mode` has to be set to `False`."
75
+ },
76
+ )
77
+
78
+ @cached_property
79
+ def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
80
+ requires_backends(self, ["tf"])
81
+ tpu = None
82
+ if self.tpu:
83
+ try:
84
+ if self.tpu_name:
85
+ tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name)
86
+ else:
87
+ tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
88
+ except ValueError:
89
+ tpu = None
90
+ return tpu
91
+
92
+ @cached_property
93
+ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
94
+ requires_backends(self, ["tf"])
95
+ if self.is_tpu:
96
+ tf.config.experimental_connect_to_cluster(self._setup_tpu)
97
+ tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
98
+
99
+ strategy = tf.distribute.TPUStrategy(self._setup_tpu)
100
+ else:
101
+ # currently no multi gpu is allowed
102
+ if self.is_gpu:
103
+ # TODO: Currently only single GPU is supported
104
+ tf.config.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
105
+ strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}")
106
+ else:
107
+ tf.config.set_visible_devices([], "GPU") # disable GPU
108
+ strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}")
109
+
110
+ return strategy
111
+
112
+ @property
113
+ def is_tpu(self) -> bool:
114
+ requires_backends(self, ["tf"])
115
+ return self._setup_tpu is not None
116
+
117
+ @property
118
+ def strategy(self) -> "tf.distribute.Strategy":
119
+ requires_backends(self, ["tf"])
120
+ return self._setup_strategy
121
+
122
+ @property
123
+ def gpu_list(self):
124
+ requires_backends(self, ["tf"])
125
+ return tf.config.list_physical_devices("GPU")
126
+
127
+ @property
128
+ def n_gpu(self) -> int:
129
+ requires_backends(self, ["tf"])
130
+ if self.cuda:
131
+ return len(self.gpu_list)
132
+ return 0
133
+
134
+ @property
135
+ def is_gpu(self) -> bool:
136
+ return self.n_gpu > 0
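The TensorFlow arguments behave analogously; a sketch of the XLA/eager interaction (illustrative, assumes TensorFlow is installed; `use_xla=True` requires `eager_mode=False`):

from transformers import TensorFlowBenchmarkArguments

args = TensorFlowBenchmarkArguments(
    models=["bert-base-uncased"],
    eager_mode=False,   # graph mode is required when use_xla=True
    use_xla=True,
)
print(args.is_tpu)     # False unless a TPU cluster resolver can be constructed
print(args.strategy)   # OneDeviceStrategy on /gpu:0 or /cpu:0, depending on hardware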
transformers_4_35_0/benchmark/benchmark_args_utils.py ADDED
@@ -0,0 +1,166 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import dataclasses
18
+ import json
19
+ import warnings
20
+ from dataclasses import dataclass, field
21
+ from time import time
22
+ from typing import List
23
+
24
+ from ..utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ def list_field(default=None, metadata=None):
31
+ return field(default_factory=lambda: default, metadata=metadata)
32
+
33
+
34
+ @dataclass
35
+ class BenchmarkArguments:
36
+ """
37
+ BenchmarkArguments are arguments we use in our benchmark scripts **which relate to the training loop itself**.
38
+
39
+ Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
40
+ line.
41
+ """
42
+
43
+ models: List[str] = list_field(
44
+ default=[],
45
+ metadata={
46
+ "help": (
47
+ "Model checkpoints to be provided to the AutoModel classes. Leave blank to benchmark the base version"
48
+ " of all available models"
49
+ )
50
+ },
51
+ )
52
+
53
+ batch_sizes: List[int] = list_field(
54
+ default=[8], metadata={"help": "List of batch sizes for which memory and time performance will be evaluated"}
55
+ )
56
+
57
+ sequence_lengths: List[int] = list_field(
58
+ default=[8, 32, 128, 512],
59
+ metadata={"help": "List of sequence lengths for which memory and time performance will be evaluated"},
60
+ )
61
+
62
+ inference: bool = field(
63
+ default=True,
64
+ metadata={"help": "Whether to benchmark inference of model. Inference can be disabled via --no-inference."},
65
+ )
66
+ cuda: bool = field(
67
+ default=True,
68
+ metadata={"help": "Whether to run on available cuda devices. Cuda can be disabled via --no-cuda."},
69
+ )
70
+ tpu: bool = field(
71
+ default=True, metadata={"help": "Whether to run on available tpu devices. TPU can be disabled via --no-tpu."}
72
+ )
73
+ fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."})
74
+ training: bool = field(default=False, metadata={"help": "Benchmark training of model"})
75
+ verbose: bool = field(default=False, metadata={"help": "Verbose memory tracing"})
76
+ speed: bool = field(
77
+ default=True,
78
+ metadata={"help": "Whether to perform speed measurements. Speed measurements can be disabled via --no-speed."},
79
+ )
80
+ memory: bool = field(
81
+ default=True,
82
+ metadata={
83
+ "help": "Whether to perform memory measurements. Memory measurements can be disabled via --no-memory"
84
+ },
85
+ )
86
+ trace_memory_line_by_line: bool = field(default=False, metadata={"help": "Trace memory line by line"})
87
+ save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"})
88
+ log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"})
89
+ env_print: bool = field(default=False, metadata={"help": "Whether to print environment information"})
90
+ multi_process: bool = field(
91
+ default=True,
92
+ metadata={
93
+ "help": (
94
+ "Whether to use multiprocessing for memory and speed measurement. It is highly recommended to use"
95
+ " multiprocessing for accurate CPU and GPU memory measurements. This option should only be disabled"
96
+ " for debugging / testing and on TPU."
97
+ )
98
+ },
99
+ )
100
+ inference_time_csv_file: str = field(
101
+ default=f"inference_time_{round(time())}.csv",
102
+ metadata={"help": "CSV filename used if saving time results to csv."},
103
+ )
104
+ inference_memory_csv_file: str = field(
105
+ default=f"inference_memory_{round(time())}.csv",
106
+ metadata={"help": "CSV filename used if saving memory results to csv."},
107
+ )
108
+ train_time_csv_file: str = field(
109
+ default=f"train_time_{round(time())}.csv",
110
+ metadata={"help": "CSV filename used if saving time results to csv for training."},
111
+ )
112
+ train_memory_csv_file: str = field(
113
+ default=f"train_memory_{round(time())}.csv",
114
+ metadata={"help": "CSV filename used if saving memory results to csv for training."},
115
+ )
116
+ env_info_csv_file: str = field(
117
+ default=f"env_info_{round(time())}.csv",
118
+ metadata={"help": "CSV filename used if saving environment information."},
119
+ )
120
+ log_filename: str = field(
121
+ default=f"log_{round(time())}.csv",
122
+ metadata={"help": "Log filename used if print statements are saved in log."},
123
+ )
124
+ repeat: int = field(default=3, metadata={"help": "Times an experiment will be run."})
125
+ only_pretrain_model: bool = field(
126
+ default=False,
127
+ metadata={
128
+ "help": (
129
+ "Instead of loading the model as defined in `config.architectures` if it exists, just load the pretrained"
130
+ " model weights."
131
+ )
132
+ },
133
+ )
134
+
135
+ def __post_init__(self):
136
+ warnings.warn(
137
+ f"The class {self.__class__} is deprecated. Hugging Face Benchmarking utils"
138
+ " are deprecated in general and it is advised to use external Benchmarking libraries "
139
+ " to benchmark Transformer models.",
140
+ FutureWarning,
141
+ )
142
+
143
+ def to_json_string(self):
144
+ """
145
+ Serializes this instance to a JSON string.
146
+ """
147
+ return json.dumps(dataclasses.asdict(self), indent=2)
148
+
149
+ @property
150
+ def model_names(self) -> List[str]:
151
+ if len(self.models) <= 0:
152
+ raise ValueError(
153
+ "Please make sure you provide at least one model name / model identifier, *e.g.* `--models"
154
+ " bert-base-cased` or `args.models = ['bert-base-cased']`."
155
+ )
156
+ return self.models
157
+
158
+ @property
159
+ def do_multi_processing(self):
160
+ if not self.multi_process:
161
+ return False
162
+ elif self.is_tpu:
163
+ logger.info("Multiprocessing is currently not possible on TPU.")
164
+ return False
165
+ else:
166
+ return True
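The class docstring above mentions `HfArgumentParser`; a minimal command-line sketch using the PyTorch subclass (script name and flag values are illustrative):

# run_benchmark.py (illustrative), invoked e.g. as:
#   python run_benchmark.py --models bert-base-cased --batch_sizes 8 --sequence_lengths 32 128
from transformers import HfArgumentParser, PyTorchBenchmark, PyTorchBenchmarkArguments

parser = HfArgumentParser(PyTorchBenchmarkArguments)
benchmark_args = parser.parse_args_into_dataclasses()[0]
PyTorchBenchmark(args=benchmark_args).run()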
transformers_4_35_0/benchmark/benchmark_tf.py ADDED
@@ -0,0 +1,303 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Benchmarking the library on inference and training in TensorFlow.
18
+ """
19
+
20
+
21
+ import random
22
+ import timeit
23
+ from functools import wraps
24
+ from typing import Callable, Optional
25
+
26
+ from ..configuration_utils import PretrainedConfig
27
+ from ..models.auto.modeling_tf_auto import TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING
28
+ from ..utils import is_py3nvml_available, is_tf_available, logging
29
+ from .benchmark_utils import (
30
+ Benchmark,
31
+ Memory,
32
+ MemorySummary,
33
+ measure_peak_memory_cpu,
34
+ start_memory_tracing,
35
+ stop_memory_tracing,
36
+ )
37
+
38
+
39
+ if is_tf_available():
40
+ import tensorflow as tf
41
+ from tensorflow.python.framework.errors_impl import ResourceExhaustedError
42
+
43
+ from .benchmark_args_tf import TensorFlowBenchmarkArguments
44
+
45
+ if is_py3nvml_available():
46
+ import py3nvml.py3nvml as nvml
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
52
+ def run_func(func):
53
+ @wraps(func)
54
+ def run_in_eager_mode(*args, **kwargs):
55
+ return func(*args, **kwargs)
56
+
57
+ @wraps(func)
58
+ @tf.function(experimental_compile=use_xla)
59
+ def run_in_graph_mode(*args, **kwargs):
60
+ return func(*args, **kwargs)
61
+
62
+ if do_eager_mode is True:
63
+ if use_xla is not False:
64
+ raise ValueError(
65
+ "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
66
+ )
67
+ return run_in_eager_mode
68
+ else:
69
+ return run_in_graph_mode
70
+
71
+ return run_func
72
+
73
+
74
+ def random_input_ids(batch_size: int, sequence_length: int, vocab_size: int) -> ["tf.Tensor"]:
75
+ rng = random.Random()
76
+ values = [rng.randint(0, vocab_size - 1) for i in range(batch_size * sequence_length)]
77
+ return tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32)
78
+
79
+
80
+ class TensorFlowBenchmark(Benchmark):
81
+ args: TensorFlowBenchmarkArguments
82
+ configs: PretrainedConfig
83
+ framework: str = "TensorFlow"
84
+
85
+ @property
86
+ def framework_version(self):
87
+ return tf.__version__
88
+
89
+ def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
90
+ # initialize GPU on separate process
91
+ strategy = self.args.strategy
92
+ if strategy is None:
93
+ raise ValueError("A device strategy has to be initialized before using TensorFlow.")
94
+ _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
95
+ return self._measure_speed(_inference)
96
+
97
+ def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
98
+ strategy = self.args.strategy
99
+ if strategy is None:
100
+ raise ValueError("A device strategy has to be initialized before using TensorFlow.")
101
+ _train = self._prepare_train_func(model_name, batch_size, sequence_length)
102
+ return self._measure_speed(_train)
103
+
104
+ def _inference_memory(
105
+ self, model_name: str, batch_size: int, sequence_length: int
106
+ ) -> [Memory, Optional[MemorySummary]]:
107
+ # initialize GPU on separate process
108
+ if self.args.is_gpu:
109
+ tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
110
+ strategy = self.args.strategy
111
+ if strategy is None:
112
+ raise ValueError("A device strategy has to be initialized before using TensorFlow.")
113
+ _inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
114
+ return self._measure_memory(_inference)
115
+
116
+ def _train_memory(
117
+ self, model_name: str, batch_size: int, sequence_length: int
118
+ ) -> [Memory, Optional[MemorySummary]]:
119
+ if self.args.is_gpu:
120
+ tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
121
+ strategy = self.args.strategy
122
+ if strategy is None:
123
+ raise ValueError("A device strategy has to be initialized before using TensorFlow.")
124
+
125
+ _train = self._prepare_train_func(model_name, batch_size, sequence_length)
126
+ return self._measure_memory(_train)
127
+
128
+ def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
129
+ config = self.config_dict[model_name]
130
+
131
+ if self.args.fp16:
132
+ raise NotImplementedError("Mixed precision is currently not supported.")
133
+
134
+ has_model_class_in_config = (
135
+ hasattr(config, "architectures")
136
+ and isinstance(config.architectures, list)
137
+ and len(config.architectures) > 0
138
+ )
139
+ if not self.args.only_pretrain_model and has_model_class_in_config:
140
+ try:
141
+ model_class = "TF" + config.architectures[0] # prepend 'TF' for tensorflow model
142
+ transformers_module = __import__("transformers", fromlist=[model_class])
143
+ model_cls = getattr(transformers_module, model_class)
144
+ model = model_cls(config)
145
+ except ImportError:
146
+ raise ImportError(
147
+ f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
148
+ " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
149
+ )
150
+ else:
151
+ model = TF_MODEL_MAPPING[config.__class__](config)
152
+
153
+ # encoder-decoder has vocab size saved differently
154
+ vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
155
+ input_ids = random_input_ids(batch_size, sequence_length, vocab_size)
156
+
157
+ @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
158
+ def encoder_decoder_forward():
159
+ return model(input_ids, decoder_input_ids=input_ids, training=False)
160
+
161
+ @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
162
+ def encoder_forward():
163
+ return model(input_ids, training=False)
164
+
165
+ _inference = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
166
+
167
+ return _inference
168
+
169
+ def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
170
+ config = self.config_dict[model_name]
171
+
172
+ if self.args.eager_mode is not False:
173
+ raise ValueError("Training cannot be done in eager mode. Please make sure that `args.eager_mode = False`.")
174
+
175
+ if self.args.fp16:
176
+ raise NotImplementedError("Mixed precision is currently not supported.")
177
+
178
+ has_model_class_in_config = (
179
+ hasattr(config, "architectures")
180
+ and isinstance(config.architectures, list)
181
+ and len(config.architectures) > 0
182
+ )
183
+ if not self.args.only_pretrain_model and has_model_class_in_config:
184
+ try:
185
+ model_class = "TF" + config.architectures[0] # prepend 'TF' for tensorflow model
186
+ transformers_module = __import__("transformers", fromlist=[model_class])
187
+ model_cls = getattr(transformers_module, model_class)
188
+ model = model_cls(config)
189
+ except ImportError:
190
+ raise ImportError(
191
+ f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
192
+ " set `--only_pretrain_model` or `args.only_pretrain_model=True`."
193
+ )
194
+ else:
195
+ model = TF_MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
196
+
197
+ # encoder-decoder has vocab size saved differently
198
+ vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
199
+ input_ids = random_input_ids(batch_size, sequence_length, vocab_size)
200
+
201
+ @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
202
+ def encoder_decoder_train():
203
+ loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids, training=True)[0]
204
+ gradients = tf.gradients(loss, model.trainable_variables)
205
+ return gradients
206
+
207
+ @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
208
+ def encoder_train():
209
+ loss = model(input_ids, labels=input_ids, training=True)[0]
210
+ gradients = tf.gradients(loss, model.trainable_variables)
211
+ return gradients
212
+
213
+ _train = encoder_decoder_train if config.is_encoder_decoder else encoder_train
214
+
215
+ return _train
216
+
217
+ def _measure_speed(self, func) -> float:
218
+ with self.args.strategy.scope():
219
+ try:
220
+ if self.args.is_tpu or self.args.use_xla:
221
+ # run additional 5 times to stabilize compilation for tpu
222
+ logger.info("Do inference on TPU. Running model 5 times to stabilize compilation")
223
+ timeit.repeat(func, repeat=1, number=5)
224
+
225
+ # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
226
+ runtimes = timeit.repeat(
227
+ func,
228
+ repeat=self.args.repeat,
229
+ number=10,
230
+ )
231
+
232
+ return min(runtimes) / 10.0
233
+ except ResourceExhaustedError as e:
234
+ self.print_fn(f"Doesn't fit on GPU. {e}")
235
+
236
+ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
237
+ logger.info(
238
+ "Note that TensorFlow allocates more memory than "
239
+ "it might need to speed up computation. "
240
+ "The memory reported here corresponds to the memory "
241
+ "reported by `nvidia-smi`, which can vary depending "
242
+ "on total available memory on the GPU that is used."
243
+ )
244
+ with self.args.strategy.scope():
245
+ try:
246
+ if self.args.trace_memory_line_by_line:
247
+ if not self.args.eager_mode:
248
+ raise ValueError(
249
+ "`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory"
250
+ " consumption line by line."
251
+ )
252
+ trace = start_memory_tracing("transformers")
253
+
254
+ if self.args.is_tpu:
255
+ # tpu
256
+ raise NotImplementedError(
257
+ "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking"
258
+ " with `args.memory=False`"
259
+ )
260
+ elif self.args.is_gpu:
261
+ # gpu
262
+ if not is_py3nvml_available():
263
+ logger.warning(
264
+ "py3nvml not installed, we won't log GPU memory usage. "
265
+ "Install py3nvml (pip install py3nvml) to log information about GPU."
266
+ )
267
+ memory = "N/A"
268
+ else:
269
+ logger.info(
270
+ "Measuring total GPU usage on GPU device. Make sure to not have additional processes"
271
+ " running on the same GPU."
272
+ )
273
+ # init nvml
274
+ nvml.nvmlInit()
275
+ func()
276
+ handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
277
+ meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
278
+ max_bytes_in_use = meminfo.used
279
+ memory = Memory(max_bytes_in_use)
280
+ # shutdown nvml
281
+ nvml.nvmlShutdown()
282
+ else:
283
+ # cpu
284
+ if self.args.trace_memory_line_by_line:
285
+ logger.info(
286
+ "When enabling line by line tracing, the max peak memory for CPU is inaccurate in"
287
+ " TensorFlow."
288
+ )
289
+ memory = None
290
+ else:
291
+ memory_bytes = measure_peak_memory_cpu(func)
292
+ memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
293
+ if self.args.trace_memory_line_by_line:
294
+ summary = stop_memory_tracing(trace)
295
+ if memory is None:
296
+ memory = summary.total
297
+ else:
298
+ summary = None
299
+
300
+ return memory, summary
301
+ except ResourceExhaustedError as e:
302
+ self.print_fn(f"Doesn't fit on GPU. {e}")
303
+ return "N/A", None
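And the TensorFlow counterpart, mirroring the PyTorch sketch earlier (illustrative checkpoint; requires TensorFlow):

from transformers import TensorFlowBenchmark, TensorFlowBenchmarkArguments

args = TensorFlowBenchmarkArguments(
    models=["bert-base-uncased"], batch_sizes=[8], sequence_lengths=[8, 32]
)
results = TensorFlowBenchmark(args).run()
print(results.time_inference_result)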
transformers_4_35_0/benchmark/benchmark_utils.py ADDED
@@ -0,0 +1,914 @@
1
+ # This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
2
+
3
+ # Copyright 2020 The HuggingFace Team and the AllenNLP authors. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Utilities for working with the local dataset cache.
18
+ """
19
+
20
+ import copy
21
+ import csv
22
+ import linecache
23
+ import os
24
+ import platform
25
+ import sys
26
+ import warnings
27
+ from abc import ABC, abstractmethod
28
+ from collections import defaultdict, namedtuple
29
+ from datetime import datetime
30
+ from multiprocessing import Pipe, Process, Queue
31
+ from multiprocessing.connection import Connection
32
+ from typing import Callable, Iterable, List, NamedTuple, Optional, Union
33
+
34
+ from .. import AutoConfig, PretrainedConfig
35
+ from .. import __version__ as version
36
+ from ..utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available, logging
37
+ from .benchmark_args_utils import BenchmarkArguments
38
+
39
+
40
+ if is_torch_available():
41
+ from torch.cuda import empty_cache as torch_empty_cache
42
+
43
+ if is_tf_available():
44
+ from tensorflow.python.eager import context as tf_context
45
+
46
+ if is_psutil_available():
47
+ import psutil
48
+
49
+ if is_py3nvml_available():
50
+ import py3nvml.py3nvml as nvml
51
+
52
+ if platform.system() == "Windows":
53
+ from signal import CTRL_C_EVENT as SIGKILL
54
+ else:
55
+ from signal import SIGKILL
56
+
57
+
58
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
+
60
+
61
+ _is_memory_tracing_enabled = False
62
+
63
+ BenchmarkOutput = namedtuple(
64
+ "BenchmarkOutput",
65
+ [
66
+ "time_inference_result",
67
+ "memory_inference_result",
68
+ "time_train_result",
69
+ "memory_train_result",
70
+ "inference_summary",
71
+ "train_summary",
72
+ ],
73
+ )
74
+
75
+
76
+ def separate_process_wrapper_fn(func: Callable[[], None], do_multi_processing: bool) -> Callable[[], None]:
77
+ """
78
+ This function wraps another function into its own separated process. In order to ensure accurate memory
79
+ measurements it is important that the function is executed in a separate process
80
+
81
+ Args:
82
+ - `func`: (`callable`): function() -> ... generic function which will be executed in its own separate process
83
+ - `do_multi_processing`: (`bool`) Whether to run function on separate process or not
84
+ """
85
+
86
+ def multi_process_func(*args, **kwargs):
87
+ # run function in an individual
88
+ # process to get correct memory
89
+ def wrapper_func(queue: Queue, *args):
90
+ try:
91
+ result = func(*args)
92
+ except Exception as e:
93
+ logger.error(e)
94
+ print(e)
95
+ result = "N/A"
96
+ queue.put(result)
97
+
98
+ queue = Queue()
99
+ p = Process(target=wrapper_func, args=[queue] + list(args))
100
+ p.start()
101
+ result = queue.get()
102
+ p.join()
103
+ return result
104
+
105
+ if do_multi_processing:
106
+ logger.info(f"Function {func} is executed in its own process...")
107
+ return multi_process_func
108
+ else:
109
+ return func
110
+
111
+
112
+ def is_memory_tracing_enabled():
113
+ global _is_memory_tracing_enabled
114
+ return _is_memory_tracing_enabled
115
+
116
+
117
+ class Frame(NamedTuple):
118
+ """
119
+ `Frame` is a NamedTuple used to gather the current frame state. `Frame` has the following fields:
120
+
121
+ - 'filename' (string): Name of the file currently executed
122
+ - 'module' (string): Name of the module currently executed
123
+ - 'line_number' (int): Number of the line currently executed
124
+ - 'event' (string): Event that triggered the tracing (default will be "line")
125
+ - 'line_text' (string): Text of the line in the python script
126
+ """
127
+
128
+ filename: str
129
+ module: str
130
+ line_number: int
131
+ event: str
132
+ line_text: str
133
+
134
+
135
+ class UsedMemoryState(NamedTuple):
136
+ """
137
+ `UsedMemoryState` are named tuples with the following fields:
138
+
139
+ - 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current file,
140
+ location in current file)
141
+ - 'cpu_memory': CPU RSS memory state *before* executing the line
142
+ - 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only `gpus_to_trace` if
143
+ provided)
144
+ """
145
+
146
+ frame: Frame
147
+ cpu_memory: int
148
+ gpu_memory: int
149
+
150
+
151
+ class Memory(NamedTuple):
152
+ """
153
+ `Memory` NamedTuple has a single field `bytes` and you can get a human readable str of the number of mega bytes by
154
+ calling `__repr__`
155
+
156
+ - `bytes` (integer): number of bytes,
157
+ """
158
+
159
+ bytes: int
160
+
161
+ def __repr__(self) -> str:
162
+ return str(bytes_to_mega_bytes(self.bytes))
163
+
164
+
165
+ class MemoryState(NamedTuple):
166
+ """
167
+ `MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields:
168
+
169
+ - `frame` (`Frame`): the current frame (see above)
170
+ - `cpu`: CPU memory consumed during the current frame as a `Memory` named tuple
171
+ - `gpu`: GPU memory consumed during the current frame as a `Memory` named tuple
172
+ - `cpu_gpu`: CPU + GPU memory consumed during the current frame as a `Memory` named tuple
173
+ """
174
+
175
+ frame: Frame
176
+ cpu: Memory
177
+ gpu: Memory
178
+ cpu_gpu: Memory
179
+
180
+
181
+ class MemorySummary(NamedTuple):
182
+ """
183
+ `MemorySummary` namedtuple with the following fields:
184
+
185
+ - `sequential`: a list of `MemoryState` namedtuple (see below) computed from the provided `memory_trace` by
186
+ subtracting the memory after executing each line from the memory before executing said line.
187
+ - `cumulative`: a list of `MemoryState` namedtuple (see below) with cumulative increase in memory for each line
188
+ obtained by summing repeated memory increase for a line if it's executed several times. The list is sorted
189
+ from the frame with the largest memory consumption to the frame with the smallest (can be negative if memory
190
+ is released)
191
+ - `total`: total memory increase during the full tracing as a `Memory` named tuple (see below). Line with
192
+ memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default).
193
+ """
194
+
195
+ sequential: List[MemoryState]
196
+ cumulative: List[MemoryState]
197
+ current: List[MemoryState]
198
+ total: Memory
199
+
200
+
201
+ MemoryTrace = List[UsedMemoryState]
202
+
203
+
204
+ def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5, device_idx=None) -> int:
205
+ """
206
+ measures peak cpu memory consumption of a given `function` running the function for at least interval seconds and
207
+ at most 20 * interval seconds. This function is heavily inspired by: `memory_usage` of the package
208
+ `memory_profiler`:
209
+ https://github.com/pythonprofilers/memory_profiler/blob/895c4ac7a08020d66ae001e24067da6dcea42451/memory_profiler.py#L239
210
+
211
+ Args:
212
+ - `function`: (`callable`): function() -> ... function without any arguments to measure for which to measure
213
+ the peak memory
214
+
215
+ - `interval`: (`float`, `optional`, defaults to `0.5`) interval in second for which to measure the memory usage
216
+
217
+ - `device_idx`: (`int`, `optional`, defaults to `None`) device id for which to measure gpu usage
218
+
219
+ Returns:
220
+
221
+ - `max_memory`: (`int`) consumed memory peak in Bytes
222
+ """
223
+
224
+ def get_cpu_memory(process_id: int) -> int:
225
+ """
226
+ measures current cpu memory usage of a given `process_id`
227
+
228
+ Args:
229
+ - `process_id`: (`int`) process_id for which to measure memory
230
+
231
+ Returns
232
+
233
+ - `memory`: (`int`) consumed memory in Bytes
234
+ """
235
+ process = psutil.Process(process_id)
236
+ try:
237
+ meminfo_attr = "memory_info" if hasattr(process, "memory_info") else "get_memory_info"
238
+ memory = getattr(process, meminfo_attr)()[0]
239
+ except psutil.AccessDenied:
240
+ raise ValueError("Error with Psutil.")
241
+ return memory
242
+
243
+ if not is_psutil_available():
244
+ logger.warning(
245
+ "Psutil not installed, we won't log CPU memory usage. "
246
+ "Install Psutil (pip install psutil) to use CPU memory tracing."
247
+ )
248
+ max_memory = "N/A"
249
+ else:
250
+
251
+ class MemoryMeasureProcess(Process):
252
+
253
+ """
254
+ `MemoryMeasureProcess` inherits from `Process` and overwrites its `run()` method. Used to measure the
255
+ memory usage of a process
256
+ """
257
+
258
+ def __init__(self, process_id: int, child_connection: Connection, interval: float):
259
+ super().__init__()
260
+ self.process_id = process_id
261
+ self.interval = interval
262
+ self.connection = child_connection
263
+ self.num_measurements = 1
264
+ self.mem_usage = get_cpu_memory(self.process_id)
265
+
266
+ def run(self):
267
+ self.connection.send(0)
268
+ stop = False
269
+ while True:
270
+ self.mem_usage = max(self.mem_usage, get_cpu_memory(self.process_id))
271
+ self.num_measurements += 1
272
+
273
+ if stop:
274
+ break
275
+
276
+ stop = self.connection.poll(self.interval)
277
+
278
+ # send results to parent pipe
279
+ self.connection.send(self.mem_usage)
280
+ self.connection.send(self.num_measurements)
281
+
282
+ while True:
283
+ # create child, parent connection
284
+ child_connection, parent_connection = Pipe()
285
+
286
+ # instantiate process
287
+ mem_process = MemoryMeasureProcess(os.getpid(), child_connection, interval)
288
+ mem_process.start()
289
+
290
+ # wait until we get memory
291
+ parent_connection.recv()
292
+
293
+ try:
294
+ # execute function
295
+ function()
296
+
297
+ # start parent connection
298
+ parent_connection.send(0)
299
+
300
+ # receive memory and num measurements
301
+ max_memory = parent_connection.recv()
302
+ num_measurements = parent_connection.recv()
303
+ except Exception:
304
+ # kill process in a clean way
305
+ parent = psutil.Process(os.getpid())
306
+ for child in parent.children(recursive=True):
307
+ os.kill(child.pid, SIGKILL)
308
+ mem_process.join(0)
309
+ raise RuntimeError("Process killed. Error in Process")
310
+
311
+ # run process at least 20 * interval or until it finishes
312
+ mem_process.join(20 * interval)
313
+
314
+ if (num_measurements > 4) or (interval < 1e-6):
315
+ break
316
+
317
+ # reduce interval
318
+ interval /= 10
319
+
320
+ return max_memory
321
+
322
+
323
+ def start_memory_tracing(
324
+ modules_to_trace: Optional[Union[str, Iterable[str]]] = None,
325
+ modules_not_to_trace: Optional[Union[str, Iterable[str]]] = None,
326
+ events_to_trace: str = "line",
327
+ gpus_to_trace: Optional[List[int]] = None,
328
+ ) -> MemoryTrace:
329
+ """
330
+ Setup line-by-line tracing to record rss mem (RAM) at each line of a module or sub-module. See `./benchmark.py` for
331
+ usage examples. Current memory consumption is returned using psutil and in particular is the RSS memory "Resident
332
+ Set Size" (the non-swapped physical memory the process is using). See
333
+ https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info
334
+
335
+ Args:
336
+ - `modules_to_trace`: (None, string, list/tuple of string) if None, all events are recorded if string or list
337
+ of strings: only events from the listed module/sub-module will be recorded (e.g. 'fairseq' or
338
+ 'transformers.models.gpt2.modeling_gpt2')
339
+ - `modules_not_to_trace`: (None, string, list/tuple of string) if None, no module is avoided if string or list
340
+ of strings: events from the listed module/sub-module will not be recorded (e.g. 'torch')
341
+ - `events_to_trace`: string or list of string of events to be recorded (see official python doc for
342
+ `sys.settrace` for the list of events) default to line
343
+ - `gpus_to_trace`: (optional list, default None) list of GPUs to trace. Default to tracing all GPUs
344
+
345
+ Return:
346
+
347
+ - `memory_trace` is a list of `UsedMemoryState` for each event (default each line of the traced script).
348
+
349
+ - `UsedMemoryState` are named tuples with the following fields:
350
+
351
+ - 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current
352
+ file, location in current file)
353
+ - 'cpu_memory': CPU RSS memory state *before* executing the line
354
+ - 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only
355
+ `gpus_to_trace` if provided)
356
+
357
+ `Frame` is a namedtuple used by `UsedMemoryState` to list the current frame state. `Frame` has the following
358
+ fields: - 'filename' (string): Name of the file currently executed - 'module' (string): Name of the module
359
+ currently executed - 'line_number' (int): Number of the line currently executed - 'event' (string): Event that
360
+ triggered the tracing (default will be "line") - 'line_text' (string): Text of the line in the python script
361
+
362
+ """
363
+ if is_psutil_available():
364
+ process = psutil.Process(os.getpid())
365
+ else:
366
+ logger.warning(
367
+ "Psutil not installed, we won't log CPU memory usage. "
368
+ "Install psutil (pip install psutil) to use CPU memory tracing."
369
+ )
370
+ process = None
371
+
372
+ if is_py3nvml_available():
373
+ try:
374
+ nvml.nvmlInit()
375
+ devices = list(range(nvml.nvmlDeviceGetCount())) if gpus_to_trace is None else gpus_to_trace
376
+ nvml.nvmlShutdown()
377
+ except (OSError, nvml.NVMLError):
378
+ logger.warning("Error while initializing communication with GPU. We won't perform GPU memory tracing.")
379
+ log_gpu = False
380
+ else:
381
+ log_gpu = is_torch_available() or is_tf_available()
382
+ else:
383
+ logger.warning(
384
+ "py3nvml not installed, we won't log GPU memory usage. "
385
+ "Install py3nvml (pip install py3nvml) to use GPU memory tracing."
386
+ )
387
+ log_gpu = False
388
+
389
+ memory_trace = []
390
+
391
+ def traceit(frame, event, args):
392
+ """
393
+ Tracing method executed before running each line in a module or sub-module Record memory allocated in a list
394
+ with debugging information
395
+ """
396
+ global _is_memory_tracing_enabled
397
+
398
+ if not _is_memory_tracing_enabled:
399
+ return traceit
400
+
401
+ # Filter events
402
+ if events_to_trace is not None:
403
+ if isinstance(events_to_trace, str) and event != events_to_trace:
404
+ return traceit
405
+ elif isinstance(events_to_trace, (list, tuple)) and event not in events_to_trace:
406
+ return traceit
407
+
408
+ if "__name__" not in frame.f_globals:
409
+ return traceit
410
+
411
+ # Filter modules
412
+ name = frame.f_globals["__name__"]
413
+ if not isinstance(name, str):
414
+ return traceit
415
+ else:
416
+ # Filter whitelist of modules to trace
417
+ if modules_to_trace is not None:
418
+ if isinstance(modules_to_trace, str) and modules_to_trace not in name:
419
+ return traceit
420
+ elif isinstance(modules_to_trace, (list, tuple)) and all(m not in name for m in modules_to_trace):
421
+ return traceit
422
+
423
+ # Filter blacklist of modules not to trace
424
+ if modules_not_to_trace is not None:
425
+ if isinstance(modules_not_to_trace, str) and modules_not_to_trace in name:
426
+ return traceit
427
+ elif isinstance(modules_not_to_trace, (list, tuple)) and any(m in name for m in modules_not_to_trace):
428
+ return traceit
429
+
430
+ # Record current tracing state (file, location in file...)
431
+ lineno = frame.f_lineno
432
+ filename = frame.f_globals["__file__"]
433
+ if filename.endswith(".pyc") or filename.endswith(".pyo"):
434
+ filename = filename[:-1]
435
+ line = linecache.getline(filename, lineno).rstrip()
436
+ traced_state = Frame(filename, name, lineno, event, line)
437
+
438
+ # Record current memory state (rss memory) and compute difference with previous memory state
439
+ cpu_mem = 0
440
+ if process is not None:
441
+ mem = process.memory_info()
442
+ cpu_mem = mem.rss
443
+
444
+ gpu_mem = 0
445
+ if log_gpu:
446
+ # Clear GPU caches
447
+ if is_torch_available():
448
+ torch_empty_cache()
449
+ if is_tf_available():
450
+ tf_context.context()._clear_caches() # See https://github.com/tensorflow/tensorflow/issues/20218#issuecomment-416771802
451
+
452
+ # Sum used memory for all GPUs
453
+ nvml.nvmlInit()
454
+
455
+ for i in devices:
456
+ handle = nvml.nvmlDeviceGetHandleByIndex(i)
457
+ meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
458
+ gpu_mem += meminfo.used
459
+
460
+ nvml.nvmlShutdown()
461
+
462
+ mem_state = UsedMemoryState(traced_state, cpu_mem, gpu_mem)
463
+ memory_trace.append(mem_state)
464
+
465
+ return traceit
466
+
467
+ sys.settrace(traceit)
468
+
469
+ global _is_memory_tracing_enabled
470
+ _is_memory_tracing_enabled = True
471
+
472
+ return memory_trace
473
+
474
+
475
+ def stop_memory_tracing(
476
+ memory_trace: Optional[MemoryTrace] = None, ignore_released_memory: bool = True
477
+ ) -> Optional[MemorySummary]:
478
+ """
479
+ Stop memory tracing cleanly and return a summary of the memory trace if a trace is given.
480
+
481
+ Args:
482
+ `memory_trace` (optional output of start_memory_tracing, default: None):
483
+ memory trace to convert in summary
484
+ `ignore_released_memory` (boolean, default: None):
485
+ if True we only sum memory increase to compute total memory
486
+
487
+ Return:
488
+
489
+ - None if `memory_trace` is None
490
+ - `MemorySummary` namedtuple otherwise with the fields:
491
+
492
+ - `sequential`: a list of `MemoryState` namedtuple (see below) computed from the provided `memory_trace` by
493
+ subtracting the memory after executing each line from the memory before executing said line.
494
+ - `cumulative`: a list of `MemoryState` namedtuple (see below) with cumulative increase in memory for each
495
+ line obtained by summing repeated memory increase for a line if it's executed several times. The list is
496
+ sorted from the frame with the largest memory consumption to the frame with the smallest (can be negative
497
+ if memory is released)
498
+ - `total`: total memory increase during the full tracing as a `Memory` named tuple (see below). Line with
499
+ memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default).
500
+
501
+ `Memory` named tuple have fields
502
+
503
+ - `byte` (integer): number of bytes,
504
+ - `string` (string): same as human readable string (ex: "3.5MB")
505
+
506
+ `Frame` are namedtuple used to list the current frame state and have the following fields:
507
+
508
+ - 'filename' (string): Name of the file currently executed
509
+ - 'module' (string): Name of the module currently executed
510
+ - 'line_number' (int): Number of the line currently executed
511
+ - 'event' (string): Event that triggered the tracing (default will be "line")
512
+ - 'line_text' (string): Text of the line in the python script
513
+
514
+ `MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields:
515
+
516
+ - `frame` (`Frame`): the current frame (see above)
517
+ - `cpu`: CPU memory consumed during the current frame as a `Memory` named tuple
518
+ - `gpu`: GPU memory consumed during the current frame as a `Memory` named tuple
519
+ - `cpu_gpu`: CPU + GPU memory consumed during the current frame as a `Memory` named tuple
520
+ """
521
+ global _is_memory_tracing_enabled
522
+ _is_memory_tracing_enabled = False
523
+
524
+ if memory_trace is not None and len(memory_trace) > 1:
525
+ memory_diff_trace = []
526
+ memory_curr_trace = []
527
+
528
+ cumulative_memory_dict = defaultdict(lambda: [0, 0, 0])
529
+
530
+ for (
531
+ (frame, cpu_mem, gpu_mem),
532
+ (next_frame, next_cpu_mem, next_gpu_mem),
533
+ ) in zip(memory_trace[:-1], memory_trace[1:]):
534
+ cpu_mem_inc = next_cpu_mem - cpu_mem
535
+ gpu_mem_inc = next_gpu_mem - gpu_mem
536
+ cpu_gpu_mem_inc = cpu_mem_inc + gpu_mem_inc
537
+ memory_diff_trace.append(
538
+ MemoryState(
539
+ frame=frame,
540
+ cpu=Memory(cpu_mem_inc),
541
+ gpu=Memory(gpu_mem_inc),
542
+ cpu_gpu=Memory(cpu_gpu_mem_inc),
543
+ )
544
+ )
545
+
546
+ memory_curr_trace.append(
547
+ MemoryState(
548
+ frame=frame,
549
+ cpu=Memory(next_cpu_mem),
550
+ gpu=Memory(next_gpu_mem),
551
+ cpu_gpu=Memory(next_gpu_mem + next_cpu_mem),
552
+ )
553
+ )
554
+
555
+ cumulative_memory_dict[frame][0] += cpu_mem_inc
556
+ cumulative_memory_dict[frame][1] += gpu_mem_inc
557
+ cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc
558
+
559
+ cumulative_memory = sorted(
560
+ cumulative_memory_dict.items(), key=lambda x: x[1][2], reverse=True
561
+ ) # order by the total CPU + GPU memory increase
562
+ cumulative_memory = [
563
+ MemoryState(
564
+ frame=frame,
565
+ cpu=Memory(cpu_mem_inc),
566
+ gpu=Memory(gpu_mem_inc),
567
+ cpu_gpu=Memory(cpu_gpu_mem_inc),
568
+ )
569
+ for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
570
+ ]
571
+
572
+ memory_curr_trace = sorted(memory_curr_trace, key=lambda x: x.cpu_gpu.bytes, reverse=True)
573
+
574
+ if ignore_released_memory:
575
+ total_memory = sum(max(0, step_trace.cpu_gpu.bytes) for step_trace in memory_diff_trace)
576
+ else:
577
+ total_memory = sum(step_trace.cpu_gpu.bytes for step_trace in memory_diff_trace)
578
+
579
+ total_memory = Memory(total_memory)
580
+
581
+ return MemorySummary(
582
+ sequential=memory_diff_trace,
583
+ cumulative=cumulative_memory,
584
+ current=memory_curr_trace,
585
+ total=total_memory,
586
+ )
587
+
588
+ return None
589
+
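# Usage sketch (editor's illustration, not part of the committed file): how
# `start_memory_tracing` (defined earlier in this module) is paired with
# `stop_memory_tracing`. The import path assumes this vendored copy is importable
# as `transformers_4_35_0`.
from transformers_4_35_0.benchmark.benchmark_utils import start_memory_tracing, stop_memory_tracing

trace = start_memory_tracing("transformers")  # restrict tracing to frames from `transformers` modules
# ... run the code whose memory you want to profile here ...
summary = stop_memory_tracing(trace)
if summary is not None:
    print(f"Total memory increase: {summary.total}")
    for state in summary.cumulative[:3]:  # top 3 lines by cumulative CPU + GPU increase
        print(state.frame.filename, state.frame.line_number, state.cpu_gpu)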
590
+
591
+ def bytes_to_mega_bytes(memory_amount: int) -> int:
592
+ """Utility to convert a number of bytes (int) into a number of mega bytes (int)"""
593
+ return memory_amount >> 20
594
+
595
+
596
+ class Benchmark(ABC):
597
+ """
598
+ Benchmark is a simple but feature-complete benchmarking utility to compare the memory and time performance of models in
599
+ Transformers.
600
+ """
601
+
602
+ args: BenchmarkArguments
603
+ configs: PretrainedConfig
604
+ framework: str
605
+
606
+ def __init__(self, args: BenchmarkArguments = None, configs: PretrainedConfig = None):
607
+ self.args = args
608
+ if configs is None:
609
+ self.config_dict = {
610
+ model_name: AutoConfig.from_pretrained(model_name) for model_name in self.args.model_names
611
+ }
612
+ else:
613
+ self.config_dict = dict(zip(self.args.model_names, configs))
614
+
615
+ warnings.warn(
616
+ f"The class {self.__class__} is deprecated. Hugging Face Benchmarking utils"
617
+ " are deprecated in general and it is advised to use external Benchmarking libraries "
618
+ " to benchmark Transformer models.",
619
+ FutureWarning,
620
+ )
621
+
622
+ if self.args.memory and os.getenv("TRANSFORMERS_USE_MULTIPROCESSING") == "0":  # env var values are strings
623
+ logger.warning(
624
+ "Memory consumption will not be measured accurately if `args.multi_process` is set to `False`. The"
625
+ " flag 'TRANSFORMERS_USE_MULTIPROCESSING' should only be disabled for debugging / testing."
626
+ )
627
+
628
+ self._print_fn = None
629
+ self._framework_version = None
630
+ self._environment_info = None
631
+
632
+ @property
633
+ def print_fn(self):
634
+ if self._print_fn is None:
635
+ if self.args.log_print:
636
+
637
+ def print_and_log(*args):
638
+ with open(self.args.log_filename, "a") as log_file:
639
+ log_file.write("".join(args) + "\n")
640
+ print(*args)
641
+
642
+ self._print_fn = print_and_log
643
+ else:
644
+ self._print_fn = print
645
+ return self._print_fn
646
+
647
+ @property
648
+ @abstractmethod
649
+ def framework_version(self):
650
+ pass
651
+
652
+ @abstractmethod
653
+ def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
654
+ pass
655
+
656
+ @abstractmethod
657
+ def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
658
+ pass
659
+
660
+ @abstractmethod
661
+ def _inference_memory(
662
+ self, model_name: str, batch_size: int, sequence_length: int
663
+ ) -> [Memory, Optional[MemorySummary]]:
664
+ pass
665
+
666
+ @abstractmethod
667
+ def _train_memory(
668
+ self, model_name: str, batch_size: int, sequence_length: int
669
+ ) -> [Memory, Optional[MemorySummary]]:
670
+ pass
671
+
672
+ def inference_speed(self, *args, **kwargs) -> float:
673
+ return separate_process_wrapper_fn(self._inference_speed, self.args.do_multi_processing)(*args, **kwargs)
674
+
675
+ def train_speed(self, *args, **kwargs) -> float:
676
+ return separate_process_wrapper_fn(self._train_speed, self.args.do_multi_processing)(*args, **kwargs)
677
+
678
+ def inference_memory(self, *args, **kwargs) -> [Memory, Optional[MemorySummary]]:
679
+ return separate_process_wrapper_fn(self._inference_memory, self.args.do_multi_processing)(*args, **kwargs)
680
+
681
+ def train_memory(self, *args, **kwargs) -> [Memory, Optional[MemorySummary]]:
682
+ return separate_process_wrapper_fn(self._train_memory, self.args.do_multi_processing)(*args, **kwargs)
683
+
684
+ def run(self):
685
+ result_dict = {model_name: {} for model_name in self.args.model_names}
686
+ inference_result_time = copy.deepcopy(result_dict)
687
+ inference_result_memory = copy.deepcopy(result_dict)
688
+ train_result_time = copy.deepcopy(result_dict)
689
+ train_result_memory = copy.deepcopy(result_dict)
690
+
691
+ for c, model_name in enumerate(self.args.model_names):
692
+ self.print_fn(f"{c + 1} / {len(self.args.model_names)}")
693
+
694
+ model_dict = {
695
+ "bs": self.args.batch_sizes,
696
+ "ss": self.args.sequence_lengths,
697
+ "result": {i: {} for i in self.args.batch_sizes},
698
+ }
699
+ inference_result_time[model_name] = copy.deepcopy(model_dict)
700
+ inference_result_memory[model_name] = copy.deepcopy(model_dict)
701
+ train_result_time[model_name] = copy.deepcopy(model_dict)
702
+ train_result_memory[model_name] = copy.deepcopy(model_dict)
703
+
704
+ inference_summary = train_summary = None
705
+
706
+ for batch_size in self.args.batch_sizes:
707
+ for sequence_length in self.args.sequence_lengths:
708
+ if self.args.inference:
709
+ if self.args.memory:
710
+ memory, inference_summary = self.inference_memory(model_name, batch_size, sequence_length)
711
+ inference_result_memory[model_name]["result"][batch_size][sequence_length] = memory
712
+ if self.args.speed:
713
+ time = self.inference_speed(model_name, batch_size, sequence_length)
714
+ inference_result_time[model_name]["result"][batch_size][sequence_length] = time
715
+
716
+ if self.args.training:
717
+ if self.args.memory:
718
+ memory, train_summary = self.train_memory(model_name, batch_size, sequence_length)
719
+ train_result_memory[model_name]["result"][batch_size][sequence_length] = memory
720
+ if self.args.speed:
721
+ time = self.train_speed(model_name, batch_size, sequence_length)
722
+ train_result_time[model_name]["result"][batch_size][sequence_length] = time
723
+
724
+ if self.args.inference:
725
+ if self.args.speed:
726
+ self.print_fn("\n" + 20 * "=" + ("INFERENCE - SPEED - RESULT").center(40) + 20 * "=")
727
+ self.print_results(inference_result_time, type_label="Time in s")
728
+ self.save_to_csv(inference_result_time, self.args.inference_time_csv_file)
729
+ if self.args.is_tpu:
730
+ self.print_fn(
731
+ "TPU was used for inference. Note that the time after compilation stabilized (after ~10"
732
+ " inferences model.forward(..) calls) was measured."
733
+ )
734
+
735
+ if self.args.memory:
736
+ self.print_fn("\n" + 20 * "=" + ("INFERENCE - MEMORY - RESULT").center(40) + 20 * "=")
737
+ self.print_results(inference_result_memory, type_label="Memory in MB")
738
+ self.save_to_csv(inference_result_memory, self.args.inference_memory_csv_file)
739
+
740
+ if self.args.trace_memory_line_by_line:
741
+ self.print_fn("\n" + 20 * "=" + ("INFERENCE - MEMORY - LINE BY LINE - SUMMARY").center(40) + 20 * "=")
742
+ self.print_memory_trace_statistics(inference_summary)
743
+
744
+ if self.args.training:
745
+ if self.args.speed:
746
+ self.print_fn("\n" + 20 * "=" + ("TRAIN - SPEED - RESULTS").center(40) + 20 * "=")
747
+ self.print_results(train_result_time, "Time in s")
748
+ self.save_to_csv(train_result_time, self.args.train_time_csv_file)
749
+ if self.args.is_tpu:
750
+ self.print_fn(
751
+ "TPU was used for training. Note that the time after compilation stabilized (after ~10 train"
752
+ " loss=model.forward(...) + loss.backward() calls) was measured."
753
+ )
754
+
755
+ if self.args.memory:
756
+ self.print_fn("\n" + 20 * "=" + ("TRAIN - MEMORY - RESULTS").center(40) + 20 * "=")
757
+ self.print_results(train_result_memory, type_label="Memory in MB")
758
+ self.save_to_csv(train_result_memory, self.args.train_memory_csv_file)
759
+
760
+ if self.args.trace_memory_line_by_line:
761
+ self.print_fn("\n" + 20 * "=" + ("TRAIN - MEMORY - LINE BY LINE - SUMMARY").center(40) + 20 * "=")
762
+ self.print_memory_trace_statistics(train_summary)
763
+
764
+ if self.args.env_print:
765
+ self.print_fn("\n" + 20 * "=" + ("ENVIRONMENT INFORMATION").center(40) + 20 * "=")
766
+ self.print_fn("\n".join([f"- {prop}: {val}" for prop, val in self.environment_info.items()]) + "\n")
767
+
768
+ if self.args.save_to_csv:
769
+ with open(self.args.env_info_csv_file, mode="w", newline="") as csv_file:
770
+ writer = csv.writer(csv_file)
771
+ for key, value in self.environment_info.items():
772
+ writer.writerow([key, value])
773
+
774
+ return BenchmarkOutput(
775
+ inference_result_time,
776
+ inference_result_memory,
777
+ train_result_time,
778
+ train_result_memory,
779
+ inference_summary,
780
+ train_summary,
781
+ )
782
+
783
+ @property
784
+ def environment_info(self):
785
+ if self._environment_info is None:
786
+ info = {}
787
+ info["transformers_version"] = version
788
+ info["framework"] = self.framework
789
+ if self.framework == "PyTorch":
790
+ info["use_torchscript"] = self.args.torchscript
791
+ if self.framework == "TensorFlow":
792
+ info["eager_mode"] = self.args.eager_mode
793
+ info["use_xla"] = self.args.use_xla
794
+ info["framework_version"] = self.framework_version
795
+ info["python_version"] = platform.python_version()
796
+ info["system"] = platform.system()
797
+ info["cpu"] = platform.processor()
798
+ info["architecture"] = platform.architecture()[0]
799
+ info["date"] = datetime.date(datetime.now())
800
+ info["time"] = datetime.time(datetime.now())
801
+ info["fp16"] = self.args.fp16
802
+ info["use_multiprocessing"] = self.args.do_multi_processing
803
+ info["only_pretrain_model"] = self.args.only_pretrain_model
804
+
805
+ if is_psutil_available():
806
+ info["cpu_ram_mb"] = bytes_to_mega_bytes(psutil.virtual_memory().total)
807
+ else:
808
+ logger.warning(
809
+ "Psutil not installed, we won't log available CPU memory. "
810
+ "Install psutil (pip install psutil) to log available CPU memory."
811
+ )
812
+ info["cpu_ram_mb"] = "N/A"
813
+
814
+ info["use_gpu"] = self.args.is_gpu
815
+ if self.args.is_gpu:
816
+ info["num_gpus"] = 1 # TODO(PVP) Currently only single GPU is supported
817
+ if is_py3nvml_available():
818
+ nvml.nvmlInit()
819
+ handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
820
+ info["gpu"] = nvml.nvmlDeviceGetName(handle)
821
+ info["gpu_ram_mb"] = bytes_to_mega_bytes(nvml.nvmlDeviceGetMemoryInfo(handle).total)
822
+ info["gpu_power_watts"] = nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000
823
+ info["gpu_performance_state"] = nvml.nvmlDeviceGetPerformanceState(handle)
824
+ nvml.nvmlShutdown()
825
+ else:
826
+ logger.warning(
827
+ "py3nvml not installed, we won't log GPU memory usage. "
828
+ "Install py3nvml (pip install py3nvml) to log information about GPU."
829
+ )
830
+ info["gpu"] = "N/A"
831
+ info["gpu_ram_mb"] = "N/A"
832
+ info["gpu_power_watts"] = "N/A"
833
+ info["gpu_performance_state"] = "N/A"
834
+
835
+ info["use_tpu"] = self.args.is_tpu
836
+ # TODO(PVP): See if we can add more information about TPU
837
+ # see: https://github.com/pytorch/xla/issues/2180
838
+
839
+ self._environment_info = info
840
+ return self._environment_info
841
+
842
+ def print_results(self, result_dict, type_label):
843
+ self.print_fn(80 * "-")
844
+ self.print_fn(
845
+ "Model Name".center(30) + "Batch Size".center(15) + "Seq Length".center(15) + type_label.center(15)
846
+ )
847
+ self.print_fn(80 * "-")
848
+ for model_name in self.args.model_names:
849
+ for batch_size in result_dict[model_name]["bs"]:
850
+ for sequence_length in result_dict[model_name]["ss"]:
851
+ result = result_dict[model_name]["result"][batch_size][sequence_length]
852
+ if isinstance(result, float):
853
+ result = round(1000 * result) / 1000
854
+ result = "< 0.001" if result == 0.0 else str(result)
855
+ else:
856
+ result = str(result)
857
+ self.print_fn(
858
+ model_name[:30].center(30) + str(batch_size).center(15),
859
+ str(sequence_length).center(15),
860
+ result.center(15),
861
+ )
862
+ self.print_fn(80 * "-")
863
+
864
+ def print_memory_trace_statistics(self, summary: MemorySummary):
865
+ self.print_fn(
866
+ "\nLine by line memory consumption:\n"
867
+ + "\n".join(
868
+ f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
869
+ for state in summary.sequential
870
+ )
871
+ )
872
+ self.print_fn(
873
+ "\nLines with top memory consumption:\n"
874
+ + "\n".join(
875
+ f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
876
+ for state in summary.cumulative[:6]
877
+ )
878
+ )
879
+ self.print_fn(
880
+ "\nLines with lowest memory consumption:\n"
881
+ + "\n".join(
882
+ f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
883
+ for state in summary.cumulative[-6:]
884
+ )
885
+ )
886
+ self.print_fn(f"\nTotal memory increase: {summary.total}")
887
+
888
+ def save_to_csv(self, result_dict, filename):
889
+ if not self.args.save_to_csv:
890
+ return
891
+ self.print_fn("Saving results to csv.")
892
+ with open(filename, mode="w") as csv_file:
893
+ if len(self.args.model_names) <= 0:
894
+ raise ValueError(f"At least 1 model should be defined, but got {self.args.model_names}")
895
+
896
+ fieldnames = ["model", "batch_size", "sequence_length"]
897
+ writer = csv.DictWriter(csv_file, fieldnames=fieldnames + ["result"])
898
+ writer.writeheader()
899
+
900
+ for model_name in self.args.model_names:
901
+ result_dict_model = result_dict[model_name]["result"]
902
+ for bs in result_dict_model:
903
+ for ss in result_dict_model[bs]:
904
+ result_model = result_dict_model[bs][ss]
905
+ writer.writerow(
906
+ {
907
+ "model": model_name,
908
+ "batch_size": bs,
909
+ "sequence_length": ss,
910
+ "result": ("{}" if not isinstance(result_model, float) else "{:.4f}").format(
911
+ result_model
912
+ ),
913
+ }
914
+ )
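# Usage sketch (editor's illustration, not part of the committed file): the abstract
# `Benchmark` above is driven through its concrete subclasses; the names below assume
# the upstream `transformers` package with its PyTorch benchmark classes.
from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

args = PyTorchBenchmarkArguments(
    models=["bert-base-uncased"],
    batch_sizes=[1],
    sequence_lengths=[8, 32],
    inference=True,
    memory=True,
    speed=True,
)
results = PyTorchBenchmark(args).run()  # returns the BenchmarkOutput assembled in `run()` above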
transformers_4_35_0/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
+ class BaseTransformersCLICommand(ABC):
20
+ @staticmethod
21
+ @abstractmethod
22
+ def register_subcommand(parser: ArgumentParser):
23
+ raise NotImplementedError()
24
+
25
+ @abstractmethod
26
+ def run(self):
27
+ raise NotImplementedError()
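# Illustration (editor's sketch, not part of the committed file): the shape of a
# concrete subcommand built on `BaseTransformersCLICommand`; `HelloCommand` and the
# "hello" subcommand are hypothetical.
from argparse import ArgumentParser


class HelloCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        sub = parser.add_parser("hello")
        sub.add_argument("--name", type=str, default="world")
        # The CLI entry point is expected to call `args.func(args)` to build the command.
        sub.set_defaults(func=lambda args: HelloCommand(args.name))

    def __init__(self, name: str):
        self.name = name

    def run(self):
        print(f"Hello, {self.name}!")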
transformers_4_35_0/commands/add_new_model.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ import shutil
18
+ import warnings
19
+ from argparse import ArgumentParser, Namespace
20
+ from pathlib import Path
21
+ from typing import List
22
+
23
+ from ..utils import logging
24
+ from . import BaseTransformersCLICommand
25
+
26
+
27
+ try:
28
+ from cookiecutter.main import cookiecutter
29
+
30
+ _has_cookiecutter = True
31
+ except ImportError:
32
+ _has_cookiecutter = False
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ def add_new_model_command_factory(args: Namespace):
38
+ return AddNewModelCommand(args.testing, args.testing_file, path=args.path)
39
+
40
+
41
+ class AddNewModelCommand(BaseTransformersCLICommand):
42
+ @staticmethod
43
+ def register_subcommand(parser: ArgumentParser):
44
+ add_new_model_parser = parser.add_parser("add-new-model")
45
+ add_new_model_parser.add_argument("--testing", action="store_true", help="If in testing mode.")
46
+ add_new_model_parser.add_argument("--testing_file", type=str, help="Configuration file on which to run.")
47
+ add_new_model_parser.add_argument(
48
+ "--path", type=str, help="Path to cookiecutter. Should only be used for testing purposes."
49
+ )
50
+ add_new_model_parser.set_defaults(func=add_new_model_command_factory)
51
+
52
+ def __init__(self, testing: bool, testing_file: str, path=None, *args):
53
+ self._testing = testing
54
+ self._testing_file = testing_file
55
+ self._path = path
56
+
57
+ def run(self):
58
+ warnings.warn(
59
+ "The command `transformers-cli add-new-model` is deprecated and will be removed in v5 of Transformers. "
60
+ "It is not actively maintained anymore, so might give a result that won't pass all tests and quality "
61
+ "checks, you should use `transformers-cli add-new-model-like` instead."
62
+ )
63
+ if not _has_cookiecutter:
64
+ raise ImportError(
65
+ "Model creation dependencies are required to use the `add_new_model` command. Install them by running "
66
+ "the following at the root of your `transformers` clone:\n\n\t$ pip install -e .[modelcreation]\n"
67
+ )
68
+ # Ensure that there is no other `cookiecutter-template-xxx` directory in the current working directory
69
+ directories = [directory for directory in os.listdir() if "cookiecutter-template-" == directory[:22]]
70
+ if len(directories) > 0:
71
+ raise ValueError(
72
+ "Several directories starting with `cookiecutter-template-` in current working directory. "
73
+ "Please clean your directory by removing all folders starting with `cookiecutter-template-` or "
74
+ "change your working directory."
75
+ )
76
+
77
+ path_to_transformer_root = (
78
+ Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent
79
+ )
80
+ path_to_cookiecutter = path_to_transformer_root / "templates" / "adding_a_new_model"
81
+
82
+ # Execute cookiecutter
83
+ if not self._testing:
84
+ cookiecutter(str(path_to_cookiecutter))
85
+ else:
86
+ with open(self._testing_file, "r") as configuration_file:
87
+ testing_configuration = json.load(configuration_file)
88
+
89
+ cookiecutter(
90
+ str(path_to_cookiecutter if self._path is None else self._path),
91
+ no_input=True,
92
+ extra_context=testing_configuration,
93
+ )
94
+
95
+ directory = [directory for directory in os.listdir() if "cookiecutter-template-" in directory[:22]][0]
96
+
97
+ # Retrieve configuration
98
+ with open(directory + "/configuration.json", "r") as configuration_file:
99
+ configuration = json.load(configuration_file)
100
+
101
+ lowercase_model_name = configuration["lowercase_modelname"]
102
+ generate_tensorflow_pytorch_and_flax = configuration["generate_tensorflow_pytorch_and_flax"]
103
+ os.remove(f"{directory}/configuration.json")
104
+
105
+ output_pytorch = "PyTorch" in generate_tensorflow_pytorch_and_flax
106
+ output_tensorflow = "TensorFlow" in generate_tensorflow_pytorch_and_flax
107
+ output_flax = "Flax" in generate_tensorflow_pytorch_and_flax
108
+
109
+ model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}"
110
+ os.makedirs(model_dir, exist_ok=True)
111
+ os.makedirs(f"{path_to_transformer_root}/tests/models/{lowercase_model_name}", exist_ok=True)
112
+
113
+ # Tests require submodules as they have parent imports
114
+ with open(f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/__init__.py", "w"):
115
+ pass
116
+
117
+ shutil.move(
118
+ f"{directory}/__init__.py",
119
+ f"{model_dir}/__init__.py",
120
+ )
121
+ shutil.move(
122
+ f"{directory}/configuration_{lowercase_model_name}.py",
123
+ f"{model_dir}/configuration_{lowercase_model_name}.py",
124
+ )
125
+
126
+ def remove_copy_lines(path):
127
+ with open(path, "r") as f:
128
+ lines = f.readlines()
129
+ with open(path, "w") as f:
130
+ for line in lines:
131
+ if "# Copied from transformers." not in line:
132
+ f.write(line)
133
+
134
+ if output_pytorch:
135
+ if not self._testing:
136
+ remove_copy_lines(f"{directory}/modeling_{lowercase_model_name}.py")
137
+
138
+ shutil.move(
139
+ f"{directory}/modeling_{lowercase_model_name}.py",
140
+ f"{model_dir}/modeling_{lowercase_model_name}.py",
141
+ )
142
+
143
+ shutil.move(
144
+ f"{directory}/test_modeling_{lowercase_model_name}.py",
145
+ f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_{lowercase_model_name}.py",
146
+ )
147
+ else:
148
+ os.remove(f"{directory}/modeling_{lowercase_model_name}.py")
149
+ os.remove(f"{directory}/test_modeling_{lowercase_model_name}.py")
150
+
151
+ if output_tensorflow:
152
+ if not self._testing:
153
+ remove_copy_lines(f"{directory}/modeling_tf_{lowercase_model_name}.py")
154
+
155
+ shutil.move(
156
+ f"{directory}/modeling_tf_{lowercase_model_name}.py",
157
+ f"{model_dir}/modeling_tf_{lowercase_model_name}.py",
158
+ )
159
+
160
+ shutil.move(
161
+ f"{directory}/test_modeling_tf_{lowercase_model_name}.py",
162
+ f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_tf_{lowercase_model_name}.py",
163
+ )
164
+ else:
165
+ os.remove(f"{directory}/modeling_tf_{lowercase_model_name}.py")
166
+ os.remove(f"{directory}/test_modeling_tf_{lowercase_model_name}.py")
167
+
168
+ if output_flax:
169
+ if not self._testing:
170
+ remove_copy_lines(f"{directory}/modeling_flax_{lowercase_model_name}.py")
171
+
172
+ shutil.move(
173
+ f"{directory}/modeling_flax_{lowercase_model_name}.py",
174
+ f"{model_dir}/modeling_flax_{lowercase_model_name}.py",
175
+ )
176
+
177
+ shutil.move(
178
+ f"{directory}/test_modeling_flax_{lowercase_model_name}.py",
179
+ f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_flax_{lowercase_model_name}.py",
180
+ )
181
+ else:
182
+ os.remove(f"{directory}/modeling_flax_{lowercase_model_name}.py")
183
+ os.remove(f"{directory}/test_modeling_flax_{lowercase_model_name}.py")
184
+
185
+ shutil.move(
186
+ f"{directory}/{lowercase_model_name}.md",
187
+ f"{path_to_transformer_root}/docs/source/en/model_doc/{lowercase_model_name}.md",
188
+ )
189
+
190
+ shutil.move(
191
+ f"{directory}/tokenization_{lowercase_model_name}.py",
192
+ f"{model_dir}/tokenization_{lowercase_model_name}.py",
193
+ )
194
+
195
+ shutil.move(
196
+ f"{directory}/tokenization_fast_{lowercase_model_name}.py",
197
+ f"{model_dir}/tokenization_{lowercase_model_name}_fast.py",
198
+ )
199
+
200
+ from os import fdopen, remove
201
+ from shutil import copymode, move
202
+ from tempfile import mkstemp
203
+
204
+ def replace(original_file: str, line_to_copy_below: str, lines_to_copy: List[str]):
205
+ # Create temp file
206
+ fh, abs_path = mkstemp()
207
+ line_found = False
208
+ with fdopen(fh, "w") as new_file:
209
+ with open(original_file) as old_file:
210
+ for line in old_file:
211
+ new_file.write(line)
212
+ if line_to_copy_below in line:
213
+ line_found = True
214
+ for line_to_copy in lines_to_copy:
215
+ new_file.write(line_to_copy)
216
+
217
+ if not line_found:
218
+ raise ValueError(f"Line {line_to_copy_below} was not found in file.")
219
+
220
+ # Copy the file permissions from the old file to the new file
221
+ copymode(original_file, abs_path)
222
+ # Remove original file
223
+ remove(original_file)
224
+ # Move new file
225
+ move(abs_path, original_file)
226
+
227
+ def skip_units(line):
228
+ return (
229
+ ("generating PyTorch" in line and not output_pytorch)
230
+ or ("generating TensorFlow" in line and not output_tensorflow)
231
+ or ("generating Flax" in line and not output_flax)
232
+ )
233
+
234
+ def replace_in_files(path_to_datafile):
235
+ with open(path_to_datafile) as datafile:
236
+ lines_to_copy = []
237
+ skip_file = False
238
+ skip_snippet = False
239
+ for line in datafile:
240
+ if "# To replace in: " in line and "##" not in line:
241
+ file_to_replace_in = line.split('"')[1]
242
+ skip_file = skip_units(line)
243
+ elif "# Below: " in line and "##" not in line:
244
+ line_to_copy_below = line.split('"')[1]
245
+ skip_snippet = skip_units(line)
246
+ elif "# End." in line and "##" not in line:
247
+ if not skip_file and not skip_snippet:
248
+ replace(file_to_replace_in, line_to_copy_below, lines_to_copy)
249
+
250
+ lines_to_copy = []
251
+ elif "# Replace with" in line and "##" not in line:
252
+ lines_to_copy = []
253
+ elif "##" not in line:
254
+ lines_to_copy.append(line)
255
+
256
+ remove(path_to_datafile)
257
+
258
+ replace_in_files(f"{directory}/to_replace_{lowercase_model_name}.py")
259
+ os.rmdir(directory)
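# Editor's note (not part of the committed file): the marker grammar that
# `replace_in_files` parses in `to_replace_<model>.py`, reconstructed from the
# branches above:
#
#   # To replace in: "path/to/target_file.py"
#   # Below: "line that anchors the insertion"
#   <lines to copy, verbatim>
#   # End.
#
# Lines containing "##" are treated as comments of the template itself and skipped;
# "# Replace with" resets the pending block. On "# End.", the collected lines are
# written into the target file right after the first line containing the anchor.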
transformers_4_35_0/commands/add_new_model_like.py ADDED
@@ -0,0 +1,1763 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import difflib
16
+ import json
17
+ import os
18
+ import re
19
+ from argparse import ArgumentParser, Namespace
20
+ from dataclasses import dataclass
21
+ from datetime import date
22
+ from itertools import chain
23
+ from pathlib import Path
24
+ from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
25
+
26
+ import yaml
27
+
28
+ from ..models import auto as auto_module
29
+ from ..models.auto.configuration_auto import model_type_to_module_name
30
+ from ..utils import is_flax_available, is_tf_available, is_torch_available, logging
31
+ from . import BaseTransformersCLICommand
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ CURRENT_YEAR = date.today().year
38
+ TRANSFORMERS_PATH = Path(__file__).parent.parent
39
+ REPO_PATH = TRANSFORMERS_PATH.parent.parent
40
+
41
+
42
+ @dataclass
43
+ class ModelPatterns:
44
+ """
45
+ Holds the basic information about a new model for the add-new-model-like command.
46
+
47
+ Args:
48
+ model_name (`str`): The model name.
49
+ checkpoint (`str`): The checkpoint to use for doc examples.
50
+ model_type (`str`, *optional*):
51
+ The model type, the identifier used internally in the library like `bert` or `xlm-roberta`. Will default to
52
+ `model_name` lowercased with spaces replaced with minuses (-).
53
+ model_lower_cased (`str`, *optional*):
54
+ The lowercased version of the model name, to use for the module name or function names. Will default to
55
+ `model_name` lowercased with spaces and minuses replaced with underscores.
56
+ model_camel_cased (`str`, *optional*):
57
+ The camel-cased version of the model name, to use for the class names. Will default to `model_name`
58
+ camel-cased (with spaces and minuses both considered as word separators).
59
+ model_upper_cased (`str`, *optional*):
60
+ The uppercased version of the model name, to use for the constant names. Will default to `model_name`
61
+ uppercased with spaces and minuses replaced with underscores.
62
+ config_class (`str`, *optional*):
63
+ The configuration class associated with this model. Will default to `"{model_camel_cased}Config"`.
64
+ tokenizer_class (`str`, *optional*):
65
+ The tokenizer class associated with this model (leave to `None` for models that don't use a tokenizer).
66
+ image_processor_class (`str`, *optional*):
67
+ The image processor class associated with this model (leave to `None` for models that don't use an image
68
+ processor).
69
+ feature_extractor_class (`str`, *optional*):
70
+ The feature extractor class associated with this model (leave to `None` for models that don't use a feature
71
+ extractor).
72
+ processor_class (`str`, *optional*):
73
+ The processor class associated with this model (leave to `None` for models that don't use a processor).
74
+ """
75
+
76
+ model_name: str
77
+ checkpoint: str
78
+ model_type: Optional[str] = None
79
+ model_lower_cased: Optional[str] = None
80
+ model_camel_cased: Optional[str] = None
81
+ model_upper_cased: Optional[str] = None
82
+ config_class: Optional[str] = None
83
+ tokenizer_class: Optional[str] = None
84
+ image_processor_class: Optional[str] = None
85
+ feature_extractor_class: Optional[str] = None
86
+ processor_class: Optional[str] = None
87
+
88
+ def __post_init__(self):
89
+ if self.model_type is None:
90
+ self.model_type = self.model_name.lower().replace(" ", "-")
91
+ if self.model_lower_cased is None:
92
+ self.model_lower_cased = self.model_name.lower().replace(" ", "_").replace("-", "_")
93
+ if self.model_camel_cased is None:
94
+ # Split the model name on - and space
95
+ words = self.model_name.split(" ")
96
+ words = list(chain(*[w.split("-") for w in words]))
97
+ # Make sure each word is capitalized
98
+ words = [w[0].upper() + w[1:] for w in words]
99
+ self.model_camel_cased = "".join(words)
100
+ if self.model_upper_cased is None:
101
+ self.model_upper_cased = self.model_name.upper().replace(" ", "_").replace("-", "_")
102
+ if self.config_class is None:
103
+ self.config_class = f"{self.model_camel_cased}Config"
104
+
105
+
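# Illustration (editor's sketch, not part of the committed file): the casings derived
# by `__post_init__` for a hypothetical model name.
p = ModelPatterns(model_name="GPT New", checkpoint="org/gpt-new-base")
assert p.model_type == "gpt-new"         # lowercased, spaces -> "-"
assert p.model_lower_cased == "gpt_new"  # lowercased, spaces and "-" -> "_"
assert p.model_camel_cased == "GPTNew"   # words capitalized and joined
assert p.model_upper_cased == "GPT_NEW"  # uppercased, spaces and "-" -> "_"
assert p.config_class == "GPTNewConfig"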
106
+ ATTRIBUTE_TO_PLACEHOLDER = {
107
+ "config_class": "[CONFIG_CLASS]",
108
+ "tokenizer_class": "[TOKENIZER_CLASS]",
109
+ "image_processor_class": "[IMAGE_PROCESSOR_CLASS]",
110
+ "feature_extractor_class": "[FEATURE_EXTRACTOR_CLASS]",
111
+ "processor_class": "[PROCESSOR_CLASS]",
112
+ "checkpoint": "[CHECKPOINT]",
113
+ "model_type": "[MODEL_TYPE]",
114
+ "model_upper_cased": "[MODEL_UPPER_CASED]",
115
+ "model_camel_cased": "[MODEL_CAMELCASED]",
116
+ "model_lower_cased": "[MODEL_LOWER_CASED]",
117
+ "model_name": "[MODEL_NAME]",
118
+ }
119
+
120
+
121
+ def is_empty_line(line: str) -> bool:
122
+ """
123
+ Determines whether a line is empty or not.
124
+ """
125
+ return len(line) == 0 or line.isspace()
126
+
127
+
128
+ def find_indent(line: str) -> int:
129
+ """
130
+ Returns the number of spaces that start a line indent.
131
+ """
132
+ search = re.search(r"^(\s*)(?:\S|$)", line)
133
+ if search is None:
134
+ return 0
135
+ return len(search.groups()[0])
136
+
137
+
138
+ def parse_module_content(content: str) -> List[str]:
139
+ """
140
+ Parse the content of a module into the list of objects it defines.
141
+
142
+ Args:
143
+ content (`str`): The content to parse
144
+
145
+ Returns:
146
+ `List[str]`: The list of objects defined in the module.
147
+ """
148
+ objects = []
149
+ current_object = []
150
+ lines = content.split("\n")
151
+ # Doc-styler takes everything between two triple quotes in docstrings, so we need a fake """ here to go with this.
152
+ end_markers = [")", "]", "}", '"""']
153
+
154
+ for line in lines:
155
+ # End of an object
156
+ is_valid_object = len(current_object) > 0
157
+ if is_valid_object and len(current_object) == 1:
158
+ is_valid_object = not current_object[0].startswith("# Copied from")
159
+ if not is_empty_line(line) and find_indent(line) == 0 and is_valid_object:
160
+ # Closing parts should be included in current object
161
+ if line in end_markers:
162
+ current_object.append(line)
163
+ objects.append("\n".join(current_object))
164
+ current_object = []
165
+ else:
166
+ objects.append("\n".join(current_object))
167
+ current_object = [line]
168
+ else:
169
+ current_object.append(line)
170
+
171
+ # Add last object
172
+ if len(current_object) > 0:
173
+ objects.append("\n".join(current_object))
174
+
175
+ return objects
176
+
177
+
178
+ def extract_block(content: str, indent_level: int = 0) -> str:
179
+ """Return the first block in `content` with the indent level `indent_level`.
180
+
181
+ The first line in `content` should be indented at `indent_level` level, otherwise an error will be thrown.
182
+
183
+ This method will immediately stop the search when a (non-empty) line with indent level less than `indent_level` is
184
+ encountered.
185
+
186
+ Args:
187
+ content (`str`): The content to parse
188
+ indent_level (`int`, *optional*, default to 0): The indent level of the blocks to search for
189
+
190
+ Returns:
191
+ indent_level (`int`, *optional*, defaults to 0): The indent level of the blocks to search for
192
+ """
193
+ current_object = []
194
+ lines = content.split("\n")
195
+ # Doc-styler takes everything between two triple quotes in docstrings, so we need a fake """ here to go with this.
196
+ end_markers = [")", "]", "}", '"""']
197
+
198
+ for idx, line in enumerate(lines):
199
+ if idx == 0 and indent_level > 0 and not is_empty_line(line) and find_indent(line) != indent_level:
200
+ raise ValueError(
201
+ f"When `indent_level > 0`, the first line in `content` should have indent level {indent_level}. Got "
202
+ f"{find_indent(line)} instead."
203
+ )
204
+
205
+ if find_indent(line) < indent_level and not is_empty_line(line):
206
+ break
207
+
208
+ # End of an object
209
+ is_valid_object = len(current_object) > 0
210
+ if (
211
+ not is_empty_line(line)
212
+ and not line.endswith(":")
213
+ and find_indent(line) == indent_level
214
+ and is_valid_object
215
+ ):
216
+ # Closing parts should be included in current object
217
+ if line.lstrip() in end_markers:
218
+ current_object.append(line)
219
+ return "\n".join(current_object)
220
+ else:
221
+ current_object.append(line)
222
+
223
+ # Add last object
224
+ if len(current_object) > 0:
225
+ return "\n".join(current_object)
226
+
227
+
228
+ def add_content_to_text(
229
+ text: str,
230
+ content: str,
231
+ add_after: Optional[Union[str, Pattern]] = None,
232
+ add_before: Optional[Union[str, Pattern]] = None,
233
+ exact_match: bool = False,
234
+ ) -> str:
235
+ """
236
+ A utility to add some content inside a given text.
237
+
238
+ Args:
239
+ text (`str`): The text in which we want to insert some content.
240
+ content (`str`): The content to add.
241
+ add_after (`str` or `Pattern`):
242
+ The pattern to test on a line of `text`, the new content is added after the first instance matching it.
243
+ add_before (`str` or `Pattern`):
244
+ The pattern to test on a line of `text`, the new content is added before the first instance matching it.
245
+ exact_match (`bool`, *optional*, defaults to `False`):
246
+ A line is considered a match with `add_after` or `add_before` if it matches exactly when `exact_match=True`,
247
+ otherwise, if `add_after`/`add_before` is present in the line.
248
+
249
+ <Tip warning={true}>
250
+
251
+ The arguments `add_after` and `add_before` are mutually exclusive, and exactly one needs to be provided.
252
+
253
+ </Tip>
254
+
255
+ Returns:
256
+ `str`: The text with the new content added if a match was found.
257
+ """
258
+ if add_after is None and add_before is None:
259
+ raise ValueError("You need to pass either `add_after` or `add_before`")
260
+ if add_after is not None and add_before is not None:
261
+ raise ValueError("You can't pass both `add_after` or `add_before`")
262
+ pattern = add_after if add_before is None else add_before
263
+
264
+ def this_is_the_line(line):
265
+ if isinstance(pattern, Pattern):
266
+ return pattern.search(line) is not None
267
+ elif exact_match:
268
+ return pattern == line
269
+ else:
270
+ return pattern in line
271
+
272
+ new_lines = []
273
+ for line in text.split("\n"):
274
+ if this_is_the_line(line):
275
+ if add_before is not None:
276
+ new_lines.append(content)
277
+ new_lines.append(line)
278
+ if add_after is not None:
279
+ new_lines.append(content)
280
+ else:
281
+ new_lines.append(line)
282
+
283
+ return "\n".join(new_lines)
284
+
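# Illustration (editor's sketch, not part of the committed file):
text = "import os\nimport re\n\nprint('hi')"
new_text = add_content_to_text(text, "import json", add_after="import re", exact_match=True)
# new_text == "import os\nimport re\nimport json\n\nprint('hi')"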
285
+
286
+ def add_content_to_file(
287
+ file_name: Union[str, os.PathLike],
288
+ content: str,
289
+ add_after: Optional[Union[str, Pattern]] = None,
290
+ add_before: Optional[Union[str, Pattern]] = None,
291
+ exact_match: bool = False,
292
+ ):
293
+ """
294
+ A utility to add some content inside a given file.
295
+
296
+ Args:
297
+ file_name (`str` or `os.PathLike`): The name of the file in which we want to insert some content.
298
+ content (`str`): The content to add.
299
+ add_after (`str` or `Pattern`):
300
+ The pattern to test on a line of `text`, the new content is added after the first instance matching it.
301
+ add_before (`str` or `Pattern`):
302
+ The pattern to test on a line of `text`, the new content is added before the first instance matching it.
303
+ exact_match (`bool`, *optional*, defaults to `False`):
304
+ A line is considered a match with `add_after` or `add_before` if it matches exactly when `exact_match=True`,
305
+ otherwise, if `add_after`/`add_before` is present in the line.
306
+
307
+ <Tip warning={true}>
308
+
309
+ The arguments `add_after` and `add_before` are mutually exclusive, and exactly one needs to be provided.
310
+
311
+ </Tip>
312
+ """
313
+ with open(file_name, "r", encoding="utf-8") as f:
314
+ old_content = f.read()
315
+
316
+ new_content = add_content_to_text(
317
+ old_content, content, add_after=add_after, add_before=add_before, exact_match=exact_match
318
+ )
319
+
320
+ with open(file_name, "w", encoding="utf-8") as f:
321
+ f.write(new_content)
322
+
323
+
324
+ def replace_model_patterns(
325
+ text: str, old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns
326
+ ) -> Tuple[str, str]:
327
+ """
328
+ Replace all patterns present in a given text.
329
+
330
+ Args:
331
+ text (`str`): The text to treat.
332
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
333
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
334
+
335
+ Returns:
336
+ `Tuple(str, str)`: A tuple with the treated text and the replacements actually performed in it.
337
+ """
338
+ # The order is crucially important as we will check and replace in that order. For instance the config probably
339
+ # contains the camel-cased name, but is treated before it.
340
+ attributes_to_check = ["config_class"]
341
+ # Add relevant preprocessing classes
342
+ for attr in ["tokenizer_class", "image_processor_class", "feature_extractor_class", "processor_class"]:
343
+ if getattr(old_model_patterns, attr) is not None and getattr(new_model_patterns, attr) is not None:
344
+ attributes_to_check.append(attr)
345
+
346
+ # Special cases for checkpoint and model_type
347
+ if old_model_patterns.checkpoint not in [old_model_patterns.model_type, old_model_patterns.model_lower_cased]:
348
+ attributes_to_check.append("checkpoint")
349
+ if old_model_patterns.model_type != old_model_patterns.model_lower_cased:
350
+ attributes_to_check.append("model_type")
351
+ else:
352
+ text = re.sub(
353
+ rf'(\s*)model_type = "{old_model_patterns.model_type}"',
354
+ r'\1model_type = "[MODEL_TYPE]"',
355
+ text,
356
+ )
357
+
358
+ # Special case when the model camel cased and upper cased names are the same for the old model (like for GPT2) but
359
+ # not the new one. We can't just do a replace in all the text and will need a special regex
360
+ if old_model_patterns.model_upper_cased == old_model_patterns.model_camel_cased:
361
+ old_model_value = old_model_patterns.model_upper_cased
362
+ if re.search(rf"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
363
+ text = re.sub(rf"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
364
+ else:
365
+ attributes_to_check.append("model_upper_cased")
366
+
367
+ attributes_to_check.extend(["model_camel_cased", "model_lower_cased", "model_name"])
368
+
369
+ # Now let's replace every other attribute by their placeholder
370
+ for attr in attributes_to_check:
371
+ text = text.replace(getattr(old_model_patterns, attr), ATTRIBUTE_TO_PLACEHOLDER[attr])
372
+
373
+ # Finally, we can replace the placeholders by the new values.
374
+ replacements = []
375
+ for attr, placeholder in ATTRIBUTE_TO_PLACEHOLDER.items():
376
+ if placeholder in text:
377
+ replacements.append((getattr(old_model_patterns, attr), getattr(new_model_patterns, attr)))
378
+ text = text.replace(placeholder, getattr(new_model_patterns, attr))
379
+
380
+ # If we have two inconsistent replacements, we don't return anything (ex: GPT2->GPT_NEW and GPT2->GPTNew)
381
+ old_replacement_values = [old for old, new in replacements]
382
+ if len(set(old_replacement_values)) != len(old_replacement_values):
383
+ return text, ""
384
+
385
+ replacements = simplify_replacements(replacements)
386
+ replacements = [f"{old}->{new}" for old, new in replacements]
387
+ return text, ",".join(replacements)
388
+
389
+
390
+ def simplify_replacements(replacements):
391
+ """
392
+ Simplify a list of replacement patterns to make sure there are no needless ones.
393
+
394
+ For instance in the sequence "Bert->BertNew, BertConfig->BertNewConfig, bert->bert_new", the replacement
395
+ "BertConfig->BertNewConfig" is implied by "Bert->BertNew" so not needed.
396
+
397
+ Args:
398
+ replacements (`List[Tuple[str, str]]`): List of patterns (old, new)
399
+
400
+ Returns:
401
+ `List[Tuple[str, str]]`: The list of patterns simplified.
402
+ """
403
+ if len(replacements) <= 1:
404
+ # Nothing to simplify
405
+ return replacements
406
+
407
+ # Next let's sort replacements by length as a replacement can only "imply" another replacement if it's shorter.
408
+ replacements.sort(key=lambda x: len(x[0]))
409
+
410
+ idx = 0
411
+ while idx < len(replacements):
412
+ old, new = replacements[idx]
413
+ # Loop through all replacements after
414
+ j = idx + 1
415
+ while j < len(replacements):
416
+ old_2, new_2 = replacements[j]
417
+ # If the replacement is implied by the current one, we can drop it.
418
+ if old_2.replace(old, new) == new_2:
419
+ replacements.pop(j)
420
+ else:
421
+ j += 1
422
+ idx += 1
423
+
424
+ return replacements
425
+
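# Illustration (editor's sketch, not part of the committed file), matching the docstring example:
simplify_replacements([("Bert", "BertNew"), ("BertConfig", "BertNewConfig"), ("bert", "bert_new")])
# -> [("Bert", "BertNew"), ("bert", "bert_new")]  ("BertConfig->BertNewConfig" is implied by "Bert->BertNew")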
426
+
427
+ def get_module_from_file(module_file: Union[str, os.PathLike]) -> str:
428
+ """
429
+ Returns the module name corresponding to a module file.
430
+ """
431
+ full_module_path = Path(module_file).absolute()
432
+ module_parts = full_module_path.with_suffix("").parts
433
+
434
+ # Find the first part named transformers, starting from the end.
435
+ idx = len(module_parts) - 1
436
+ while idx >= 0 and module_parts[idx] != "transformers":
437
+ idx -= 1
438
+ if idx < 0:
439
+ raise ValueError(f"{module_file} is not a transformers module.")
440
+
441
+ return ".".join(module_parts[idx:])
442
+
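# Illustration (editor's sketch, not part of the committed file):
# get_module_from_file("/repo/src/transformers/models/bert/modeling_bert.py")
# -> "transformers.models.bert.modeling_bert"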
443
+
444
+ SPECIAL_PATTERNS = {
445
+ "_CHECKPOINT_FOR_DOC =": "checkpoint",
446
+ "_CONFIG_FOR_DOC =": "config_class",
447
+ "_TOKENIZER_FOR_DOC =": "tokenizer_class",
448
+ "_IMAGE_PROCESSOR_FOR_DOC =": "image_processor_class",
449
+ "_FEAT_EXTRACTOR_FOR_DOC =": "feature_extractor_class",
450
+ "_PROCESSOR_FOR_DOC =": "processor_class",
451
+ }
452
+
453
+
454
+ _re_class_func = re.compile(r"^(?:class|def)\s+([^\s:\(]+)\s*(?:\(|\:)", flags=re.MULTILINE)
455
+
456
+
457
+ def remove_attributes(obj, target_attr):
458
+ """Remove `target_attr` in `obj`."""
459
+ lines = obj.split(os.linesep)
460
+
461
+ target_idx = None
462
+ for idx, line in enumerate(lines):
463
+ # search for assignment
464
+ if line.lstrip().startswith(f"{target_attr} = "):
465
+ target_idx = idx
466
+ break
467
+ # search for function/method definition
468
+ elif line.lstrip().startswith(f"def {target_attr}("):
469
+ target_idx = idx
470
+ break
471
+
472
+ # target not found
473
+ if target_idx is None:
474
+ return obj
475
+
476
+ line = lines[target_idx]
477
+ indent_level = find_indent(line)
478
+ # forward pass to find the ending of the block (including empty lines)
479
+ parsed = extract_block("\n".join(lines[target_idx:]), indent_level)
480
+ num_lines = len(parsed.split("\n"))
481
+ for idx in range(num_lines):
482
+ lines[target_idx + idx] = None
483
+
484
+ # backward pass to find comments or decorator
485
+ for idx in range(target_idx - 1, -1, -1):
486
+ line = lines[idx]
487
+ if (line.lstrip().startswith("#") or line.lstrip().startswith("@")) and find_indent(line) == indent_level:
488
+ lines[idx] = None
489
+ else:
490
+ break
491
+
492
+ new_obj = os.linesep.join([x for x in lines if x is not None])
493
+
494
+ return new_obj
495
+
496
+
497
+ def duplicate_module(
498
+ module_file: Union[str, os.PathLike],
499
+ old_model_patterns: ModelPatterns,
500
+ new_model_patterns: ModelPatterns,
501
+ dest_file: Optional[str] = None,
502
+ add_copied_from: bool = True,
503
+ attrs_to_remove: List[str] = None,
504
+ ):
505
+ """
506
+ Create a new module from an existing one, adapting all function and class names from old patterns to new ones.
507
+
508
+ Args:
509
+ module_file (`str` or `os.PathLike`): Path to the module to duplicate.
510
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
511
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
512
+ dest_file (`str` or `os.PathLike`, *optional*): Path to the new module.
513
+ add_copied_from (`bool`, *optional*, defaults to `True`):
514
+ Whether or not to add `# Copied from` statements in the duplicated module.
515
+ """
516
+ if dest_file is None:
517
+ dest_file = str(module_file).replace(
518
+ old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
519
+ )
520
+
521
+ with open(module_file, "r", encoding="utf-8") as f:
522
+ content = f.read()
523
+
524
+ content = re.sub(r"# Copyright (\d+)\s", f"# Copyright {CURRENT_YEAR} ", content)
525
+ objects = parse_module_content(content)
526
+
527
+ # Loop and treat all objects
528
+ new_objects = []
529
+ for obj in objects:
530
+ # Special cases
531
+ if "PRETRAINED_CONFIG_ARCHIVE_MAP = {" in obj:
532
+ # docstyle-ignore
533
+ obj = (
534
+ f"{new_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP = "
535
+ + "{"
536
+ + f"""
537
+ "{new_model_patterns.checkpoint}": "https://huggingface.co/{new_model_patterns.checkpoint}/resolve/main/config.json",
538
+ """
539
+ + "}\n"
540
+ )
541
+ new_objects.append(obj)
542
+ continue
543
+ elif "PRETRAINED_MODEL_ARCHIVE_LIST = [" in obj:
544
+ if obj.startswith("TF_"):
545
+ prefix = "TF_"
546
+ elif obj.startswith("FLAX_"):
547
+ prefix = "FLAX_"
548
+ else:
549
+ prefix = ""
550
+ # docstyle-ignore
551
+ obj = f"""{prefix}{new_model_patterns.model_upper_cased}_PRETRAINED_MODEL_ARCHIVE_LIST = [
552
+ "{new_model_patterns.checkpoint}",
553
+ # See all {new_model_patterns.model_name} models at https://huggingface.co/models?filter={new_model_patterns.model_type}
554
+ ]
555
+ """
556
+ new_objects.append(obj)
557
+ continue
558
+
559
+ special_pattern = False
560
+ for pattern, attr in SPECIAL_PATTERNS.items():
561
+ if pattern in obj:
562
+ obj = obj.replace(getattr(old_model_patterns, attr), getattr(new_model_patterns, attr))
563
+ new_objects.append(obj)
564
+ special_pattern = True
565
+ break
566
+
567
+ if special_pattern:
568
+ continue
569
+
570
+ # Regular classes functions
571
+ old_obj = obj
572
+ obj, replacement = replace_model_patterns(obj, old_model_patterns, new_model_patterns)
573
+ has_copied_from = re.search(r"^#\s+Copied from", obj, flags=re.MULTILINE) is not None
574
+ if add_copied_from and not has_copied_from and _re_class_func.search(obj) is not None and len(replacement) > 0:
575
+ # Copied from statement must be added just before the class/function definition, which may not be the
576
+ # first line because of decorators.
577
+ module_name = get_module_from_file(module_file)
578
+ old_object_name = _re_class_func.search(old_obj).groups()[0]
579
+ obj = add_content_to_text(
580
+ obj, f"# Copied from {module_name}.{old_object_name} with {replacement}", add_before=_re_class_func
581
+ )
582
+ # In all cases, we remove Copied from statement with indent on methods.
583
+ obj = re.sub("\n[ ]+# Copied from [^\n]*\n", "\n", obj)
584
+
585
+ new_objects.append(obj)
586
+
587
+ content = "\n".join(new_objects)
588
+ # Remove some attributes that we don't want to copy to the new file(s)
589
+ if attrs_to_remove is not None:
590
+ for attr in attrs_to_remove:
591
+ content = remove_attributes(content, target_attr=attr)
592
+
593
+ with open(dest_file, "w", encoding="utf-8") as f:
594
+ f.write(content)
595
+
596
+
597
+ def filter_framework_files(
598
+ files: List[Union[str, os.PathLike]], frameworks: Optional[List[str]] = None
599
+ ) -> List[Union[str, os.PathLike]]:
600
+ """
601
+ Filter a list of files to only keep the ones corresponding to a list of frameworks.
602
+
603
+ Args:
604
+ files (`List[Union[str, os.PathLike]]`): The list of files to filter.
605
+ frameworks (`List[str]`, *optional*): The list of allowed frameworks.
606
+
607
+ Returns:
608
+ `List[Union[str, os.PathLike]]`: The list of filtered files.
609
+ """
610
+ if frameworks is None:
611
+ frameworks = get_default_frameworks()
612
+
613
+ framework_to_file = {}
614
+ others = []
615
+ for f in files:
616
+ parts = Path(f).name.split("_")
617
+ if "modeling" not in parts:
618
+ others.append(f)
619
+ continue
620
+ if "tf" in parts:
621
+ framework_to_file["tf"] = f
622
+ elif "flax" in parts:
623
+ framework_to_file["flax"] = f
624
+ else:
625
+ framework_to_file["pt"] = f
626
+
627
+ return [framework_to_file[f] for f in frameworks if f in framework_to_file] + others
628
+
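# Illustration (editor's sketch, not part of the committed file): modeling files are
# keyed by framework, everything else is kept as-is.
filter_framework_files(
    ["modeling_bert.py", "modeling_tf_bert.py", "configuration_bert.py"], frameworks=["pt"]
)
# -> ["modeling_bert.py", "configuration_bert.py"]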
629
+
630
+ def get_model_files(model_type: str, frameworks: Optional[List[str]] = None) -> Dict[str, Union[Path, List[Path]]]:
631
+ """
632
+ Retrieves all the files associated to a model.
633
+
634
+ Args:
635
+ model_type (`str`): A valid model type (like "bert" or "gpt2")
636
+ frameworks (`List[str]`, *optional*):
637
+ If passed, will only keep the model files corresponding to the passed frameworks.
638
+
639
+ Returns:
640
+ `Dict[str, Union[Path, List[Path]]]`: A dictionary with the following keys:
641
+ - **doc_file** -- The documentation file for the model.
642
+ - **model_files** -- All the files in the model module.
643
+ - **test_files** -- The test files for the model.
644
+ """
645
+ module_name = model_type_to_module_name(model_type)
646
+
647
+ model_module = TRANSFORMERS_PATH / "models" / module_name
648
+ model_files = list(model_module.glob("*.py"))
649
+ model_files = filter_framework_files(model_files, frameworks=frameworks)
650
+
651
+ doc_file = REPO_PATH / "docs" / "source" / "en" / "model_doc" / f"{model_type}.md"
652
+
653
+ # Basic pattern for test files
654
+ test_files = [
655
+ f"test_modeling_{module_name}.py",
656
+ f"test_modeling_tf_{module_name}.py",
657
+ f"test_modeling_flax_{module_name}.py",
658
+ f"test_tokenization_{module_name}.py",
659
+ f"test_image_processing_{module_name}.py",
660
+ f"test_feature_extraction_{module_name}.py",
661
+ f"test_processor_{module_name}.py",
662
+ ]
663
+ test_files = filter_framework_files(test_files, frameworks=frameworks)
664
+ # Add the test directory
665
+ test_files = [REPO_PATH / "tests" / "models" / module_name / f for f in test_files]
666
+ # Filter by existing files
667
+ test_files = [f for f in test_files if f.exists()]
668
+
669
+ return {"doc_file": doc_file, "model_files": model_files, "module_name": module_name, "test_files": test_files}
670
+
671
+
672
+ _re_checkpoint_for_doc = re.compile(r"^_CHECKPOINT_FOR_DOC\s+=\s+(\S*)\s*$", flags=re.MULTILINE)
673
+
674
+
675
+ def find_base_model_checkpoint(
676
+ model_type: str, model_files: Optional[Dict[str, Union[Path, List[Path]]]] = None
677
+ ) -> str:
678
+ """
679
+ Finds the model checkpoint used in the docstrings for a given model.
680
+
681
+ Args:
682
+ model_type (`str`): A valid model type (like "bert" or "gpt2")
683
+ model_files (`Dict[str, Union[Path, List[Path]]]`, *optional*):
684
+ The files associated to `model_type`. Can be passed to speed up the function, otherwise will be computed.
685
+
686
+ Returns:
687
+ `str`: The checkpoint used.
688
+ """
689
+ if model_files is None:
690
+ model_files = get_model_files(model_type)
691
+ module_files = model_files["model_files"]
692
+ for fname in module_files:
693
+ if "modeling" not in str(fname):
694
+ continue
695
+
696
+ with open(fname, "r", encoding="utf-8") as f:
697
+ content = f.read()
698
+ if _re_checkpoint_for_doc.search(content) is not None:
699
+ checkpoint = _re_checkpoint_for_doc.search(content).groups()[0]
700
+ # Remove quotes
701
+ checkpoint = checkpoint.replace('"', "")
702
+ checkpoint = checkpoint.replace("'", "")
703
+ return checkpoint
704
+
705
+ # TODO: Find some kind of fallback if there is no _CHECKPOINT_FOR_DOC in any of the modeling file.
706
+ return ""
707
+
708
+
709
+ def get_default_frameworks():
710
+ """
711
+ Returns the list of frameworks (PyTorch, TensorFlow, Flax) that are installed in the environment.
712
+ """
713
+ frameworks = []
714
+ if is_torch_available():
715
+ frameworks.append("pt")
716
+ if is_tf_available():
717
+ frameworks.append("tf")
718
+ if is_flax_available():
719
+ frameworks.append("flax")
720
+ return frameworks
721
+
722
+
723
+ _re_model_mapping = re.compile("MODEL_([A-Z_]*)MAPPING_NAMES")
724
+
725
+
726
+ def retrieve_model_classes(model_type: str, frameworks: Optional[List[str]] = None) -> Dict[str, List[str]]:
727
+ """
728
+ Retrieve the model classes associated to a given model.
729
+
730
+ Args:
731
+ model_type (`str`): A valid model type (like "bert" or "gpt2")
732
+ frameworks (`List[str]`, *optional*):
733
+ The frameworks to look for. Will default to `["pt", "tf", "flax"]`, passing a smaller list will restrict
734
+ the classes returned.
735
+
736
+ Returns:
737
+ `Dict[str, List[str]]`: A dictionary with one key per framework and the list of model classes associated to
738
+ that framework as values.
739
+ """
740
+ if frameworks is None:
741
+ frameworks = get_default_frameworks()
742
+
743
+ modules = {
744
+ "pt": auto_module.modeling_auto if is_torch_available() else None,
745
+ "tf": auto_module.modeling_tf_auto if is_tf_available() else None,
746
+ "flax": auto_module.modeling_flax_auto if is_flax_available() else None,
747
+ }
748
+
749
+ model_classes = {}
750
+ for framework in frameworks:
751
+ new_model_classes = []
752
+ if modules[framework] is None:
753
+ raise ValueError(f"You selected {framework} in the frameworks, but it is not installed.")
754
+ model_mappings = [attr for attr in dir(modules[framework]) if _re_model_mapping.search(attr) is not None]
755
+ for model_mapping_name in model_mappings:
756
+ model_mapping = getattr(modules[framework], model_mapping_name)
757
+ if model_type in model_mapping:
758
+ new_model_classes.append(model_mapping[model_type])
759
+
760
+ if len(new_model_classes) > 0:
761
+ # Remove duplicates
762
+ model_classes[framework] = list(set(new_model_classes))
763
+
764
+ return model_classes
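A hedged sketch of the output shape (editor's illustration; the actual class names depend on the installed frameworks and library version):

# With only PyTorch installed, the result has a single "pt" key:
classes = retrieve_model_classes("bert", frameworks=["pt"])
# classes -> {"pt": ["BertModel", "BertForMaskedLM", ...]}  (duplicates removed, order not guaranteed)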
765
+
766
+
767
+ def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None):
768
+ """
769
+ Retrieves all the information from a given model_type.
770
+
771
+ Args:
772
+ model_type (`str`): A valid model type (like "bert" or "gpt2")
773
+ frameworks (`List[str]`, *optional*):
774
+ If passed, will only keep the info corresponding to the passed frameworks.
775
+
776
+ Returns:
777
+ `Dict`: A dictionary with the following keys:
778
+ - **frameworks** (`List[str]`): The list of frameworks that back this model type.
779
+ - **model_classes** (`Dict[str, List[str]]`): The model classes implemented for that model type.
780
+ - **model_files** (`Dict[str, Union[Path, List[Path]]]`): The files associated with that model type.
781
+ - **model_patterns** (`ModelPatterns`): The various patterns for the model.
782
+ """
783
+ if model_type not in auto_module.MODEL_NAMES_MAPPING:
784
+ raise ValueError(f"{model_type} is not a valid model type.")
785
+
786
+ model_name = auto_module.MODEL_NAMES_MAPPING[model_type]
787
+ config_class = auto_module.configuration_auto.CONFIG_MAPPING_NAMES[model_type]
788
+ archive_map = auto_module.configuration_auto.CONFIG_ARCHIVE_MAP_MAPPING_NAMES.get(model_type, None)
789
+ if model_type in auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES:
790
+ tokenizer_classes = auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES[model_type]
791
+ tokenizer_class = tokenizer_classes[0] if tokenizer_classes[0] is not None else tokenizer_classes[1]
792
+ else:
793
+ tokenizer_class = None
794
+ image_processor_class = auto_module.image_processing_auto.IMAGE_PROCESSOR_MAPPING_NAMES.get(model_type, None)
795
+ feature_extractor_class = auto_module.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES.get(model_type, None)
796
+ processor_class = auto_module.processing_auto.PROCESSOR_MAPPING_NAMES.get(model_type, None)
797
+
798
+ model_files = get_model_files(model_type, frameworks=frameworks)
799
+ model_camel_cased = config_class.replace("Config", "")
800
+
801
+ available_frameworks = []
802
+ for fname in model_files["model_files"]:
803
+ if "modeling_tf" in str(fname):
804
+ available_frameworks.append("tf")
805
+ elif "modeling_flax" in str(fname):
806
+ available_frameworks.append("flax")
807
+ elif "modeling" in str(fname):
808
+ available_frameworks.append("pt")
809
+
810
+ if frameworks is None:
811
+ frameworks = get_default_frameworks()
812
+
813
+ frameworks = [f for f in frameworks if f in available_frameworks]
814
+
815
+ model_classes = retrieve_model_classes(model_type, frameworks=frameworks)
816
+
817
+ # Retrieve model upper-cased name from the constant name of the pretrained archive map.
818
+ if archive_map is None:
819
+ model_upper_cased = model_camel_cased.upper()
820
+ else:
821
+ parts = archive_map.split("_")
822
+ idx = 0
823
+ while idx < len(parts) and parts[idx] != "PRETRAINED":
824
+ idx += 1
825
+ if idx < len(parts):
826
+ model_upper_cased = "_".join(parts[:idx])
827
+ else:
828
+ model_upper_cased = model_camel_cased.upper()
829
+
830
+ model_patterns = ModelPatterns(
831
+ model_name,
832
+ checkpoint=find_base_model_checkpoint(model_type, model_files=model_files),
833
+ model_type=model_type,
834
+ model_camel_cased=model_camel_cased,
835
+ model_lower_cased=model_files["module_name"],
836
+ model_upper_cased=model_upper_cased,
837
+ config_class=config_class,
838
+ tokenizer_class=tokenizer_class,
839
+ image_processor_class=image_processor_class,
840
+ feature_extractor_class=feature_extractor_class,
841
+ processor_class=processor_class,
842
+ )
843
+
844
+ return {
845
+ "frameworks": frameworks,
846
+ "model_classes": model_classes,
847
+ "model_files": model_files,
848
+ "model_patterns": model_patterns,
849
+ }
850
+
851
+
852
+ def clean_frameworks_in_init(
853
+ init_file: Union[str, os.PathLike], frameworks: Optional[List[str]] = None, keep_processing: bool = True
854
+ ):
855
+ """
856
+ Removes from an init file all the import lines tied to frameworks that are not in the given list and, optionally,
857
+ the ones that concern tokenizers/feature extractors/image processors/processors.
858
+
859
+ Args:
860
+ init_file (`str` or `os.PathLike`): The path to the init to treat.
861
+ frameworks (`List[str]`, *optional*):
862
+ If passed, this will remove all imports that are subject to a framework not in `frameworks`.
863
+ keep_processing (`bool`, *optional*, defaults to `True`):
864
+ Whether or not to keep the preprocessing (tokenizer, feature extractor, image processor, processor) imports
865
+ in the init.
866
+ """
867
+ if frameworks is None:
868
+ frameworks = get_default_frameworks()
869
+
870
+ names = {"pt": "torch"}
871
+ to_remove = [names.get(f, f) for f in ["pt", "tf", "flax"] if f not in frameworks]
872
+ if not keep_processing:
873
+ to_remove.extend(["sentencepiece", "tokenizers", "vision"])
874
+
875
+ if len(to_remove) == 0:
876
+ # Nothing to do
877
+ return
878
+
879
+ remove_pattern = "|".join(to_remove)
880
+ re_conditional_imports = re.compile(rf"^\s*if not is_({remove_pattern})_available\(\):\s*$")
881
+ re_try = re.compile(r"\s*try:")
882
+ re_else = re.compile(r"\s*else:")
883
+ re_is_xxx_available = re.compile(rf"is_({remove_pattern})_available")
884
+
885
+ with open(init_file, "r", encoding="utf-8") as f:
886
+ content = f.read()
887
+
888
+ lines = content.split("\n")
889
+ new_lines = []
890
+ idx = 0
891
+ while idx < len(lines):
892
+ # Conditional imports in try-except-else blocks
893
+ if (re_conditional_imports.search(lines[idx]) is not None) and (re_try.search(lines[idx - 1]) is not None):
894
+ # Remove the preceding `try:`
895
+ new_lines.pop()
896
+ idx += 1
897
+ # Iterate until `else:`
898
+ while is_empty_line(lines[idx]) or re_else.search(lines[idx]) is None:
899
+ idx += 1
900
+ idx += 1
901
+ indent = find_indent(lines[idx])
902
+ while find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]):
903
+ idx += 1
904
+ # Remove the import from utils
905
+ elif re_is_xxx_available.search(lines[idx]) is not None:
906
+ line = lines[idx]
907
+ for framework in to_remove:
908
+ line = line.replace(f", is_{framework}_available", "")
909
+ line = line.replace(f"is_{framework}_available, ", "")
910
+ line = line.replace(f"is_{framework}_available,", "")
911
+ line = line.replace(f"is_{framework}_available", "")
912
+
913
+ if len(line.strip()) > 0:
914
+ new_lines.append(line)
915
+ idx += 1
916
+ # Otherwise we keep the line, except if it's a processing-class import and we don't want to keep it.
917
+ elif keep_processing or (
918
+ re.search(r'^\s*"(tokenization|processing|feature_extraction|image_processing)', lines[idx]) is None
919
+ and re.search(r"^\s*from .(tokenization|processing|feature_extraction|image_processing)", lines[idx])
920
+ is None
921
+ ):
922
+ new_lines.append(lines[idx])
923
+ idx += 1
924
+ else:
925
+ idx += 1
926
+
927
+ with open(init_file, "w", encoding="utf-8") as f:
928
+ f.write("\n".join(new_lines))
929
+
930
+
931
+ def add_model_to_main_init(
932
+ old_model_patterns: ModelPatterns,
933
+ new_model_patterns: ModelPatterns,
934
+ frameworks: Optional[List[str]] = None,
935
+ with_processing: bool = True,
936
+ ):
937
+ """
938
+ Add a model to the main init of Transformers.
939
+
940
+ Args:
941
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
942
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
943
+ frameworks (`List[str]`, *optional*):
944
+ If specified, only the models implemented in those frameworks will be added.
945
+ with_processing (`bool`, *optional*, defaults to `True`):
946
+ Whether the tokenizer/feature extractor/processor of the model should also be added to the init or not.
947
+ """
948
+ with open(TRANSFORMERS_PATH / "__init__.py", "r", encoding="utf-8") as f:
949
+ content = f.read()
950
+
951
+ lines = content.split("\n")
952
+ idx = 0
953
+ new_lines = []
954
+ framework = None
955
+ while idx < len(lines):
956
+ new_framework = False
957
+ if not is_empty_line(lines[idx]) and find_indent(lines[idx]) == 0:
958
+ framework = None
959
+ elif lines[idx].lstrip().startswith("if not is_torch_available"):
960
+ framework = "pt"
961
+ new_framework = True
962
+ elif lines[idx].lstrip().startswith("if not is_tf_available"):
963
+ framework = "tf"
964
+ new_framework = True
965
+ elif lines[idx].lstrip().startswith("if not is_flax_available"):
966
+ framework = "flax"
967
+ new_framework = True
968
+
969
+ if new_framework:
970
+ # For a new framework, we need to skip until the else: block to get where the imports are.
971
+ while lines[idx].strip() != "else:":
972
+ new_lines.append(lines[idx])
973
+ idx += 1
974
+
975
+ # Skip if we are in a framework not wanted.
976
+ if framework is not None and frameworks is not None and framework not in frameworks:
977
+ new_lines.append(lines[idx])
978
+ idx += 1
979
+ elif re.search(rf'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
980
+ block = [lines[idx]]
981
+ indent = find_indent(lines[idx])
982
+ idx += 1
983
+ while find_indent(lines[idx]) > indent:
984
+ block.append(lines[idx])
985
+ idx += 1
986
+ if lines[idx].strip() in [")", "]", "],"]:
987
+ block.append(lines[idx])
988
+ idx += 1
989
+ block = "\n".join(block)
990
+ new_lines.append(block)
991
+
992
+ add_block = True
993
+ if not with_processing:
994
+ processing_classes = [
995
+ old_model_patterns.tokenizer_class,
996
+ old_model_patterns.image_processor_class,
997
+ old_model_patterns.feature_extractor_class,
998
+ old_model_patterns.processor_class,
999
+ ]
1000
+ # Only keep the ones that are not None
1001
+ processing_classes = [c for c in processing_classes if c is not None]
1002
+ for processing_class in processing_classes:
1003
+ block = block.replace(f' "{processing_class}",', "")
1004
+ block = block.replace(f', "{processing_class}"', "")
1005
+ block = block.replace(f" {processing_class},", "")
1006
+ block = block.replace(f", {processing_class}", "")
1007
+
1008
+ if processing_class in block:
1009
+ add_block = False
1010
+ if add_block:
1011
+ new_lines.append(replace_model_patterns(block, old_model_patterns, new_model_patterns)[0])
1012
+ else:
1013
+ new_lines.append(lines[idx])
1014
+ idx += 1
1015
+
1016
+ with open(TRANSFORMERS_PATH / "__init__.py", "w", encoding="utf-8") as f:
1017
+ f.write("\n".join(new_lines))
1018
+
1019
+
1020
+ def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns):
1021
+ """
1022
+ Add a tokenizer to the relevant mappings in the auto module.
1023
+
1024
+ Args:
1025
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
1026
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
1027
+ """
1028
+ if old_model_patterns.tokenizer_class is None or new_model_patterns.tokenizer_class is None:
1029
+ return
1030
+
1031
+ with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "r", encoding="utf-8") as f:
1032
+ content = f.read()
1033
+
1034
+ lines = content.split("\n")
1035
+ idx = 0
1036
+ # First we get to the TOKENIZER_MAPPING_NAMES block.
1037
+ while not lines[idx].startswith(" TOKENIZER_MAPPING_NAMES = OrderedDict("):
1038
+ idx += 1
1039
+ idx += 1
1040
+
1041
+ # That block ends when we reach the following line:
1042
+ while not lines[idx].startswith("TOKENIZER_MAPPING = _LazyAutoMapping"):
1043
+ # Either the whole tokenizer block is defined on one line, in which case it ends with ","
1044
+ if lines[idx].endswith(","):
1045
+ block = lines[idx]
1046
+ # Otherwise it takes several lines until we get to a "),"
1047
+ else:
1048
+ block = []
1049
+ while not lines[idx].startswith(" ),"):
1050
+ block.append(lines[idx])
1051
+ idx += 1
1052
+ block = "\n".join(block)
1053
+ idx += 1
1054
+
1055
+ # If we find the model type and tokenizer class in that block, we have the old model tokenizer block
1056
+ if f'"{old_model_patterns.model_type}"' in block and old_model_patterns.tokenizer_class in block:
1057
+ break
1058
+
1059
+ new_block = block.replace(old_model_patterns.model_type, new_model_patterns.model_type)
1060
+ new_block = new_block.replace(old_model_patterns.tokenizer_class, new_model_patterns.tokenizer_class)
1061
+
1062
+ new_lines = lines[:idx] + [new_block] + lines[idx:]
1063
+ with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "w", encoding="utf-8") as f:
1064
+ f.write("\n".join(new_lines))
1065
+
1066
+
1067
+ AUTO_CLASSES_PATTERNS = {
1068
+ "configuration_auto.py": [
1069
+ ' ("{model_type}", "{model_name}"),',
1070
+ ' ("{model_type}", "{config_class}"),',
1071
+ ' ("{model_type}", "{pretrained_archive_map}"),',
1072
+ ],
1073
+ "feature_extraction_auto.py": [' ("{model_type}", "{feature_extractor_class}"),'],
1074
+ "image_processing_auto.py": [' ("{model_type}", "{image_processor_class}"),'],
1075
+ "modeling_auto.py": [' ("{model_type}", "{any_pt_class}"),'],
1076
+ "modeling_tf_auto.py": [' ("{model_type}", "{any_tf_class}"),'],
1077
+ "modeling_flax_auto.py": [' ("{model_type}", "{any_flax_class}"),'],
1078
+ "processing_auto.py": [' ("{model_type}", "{processor_class}"),'],
1079
+ }
1080
+
1081
+
1082
+ def add_model_to_auto_classes(
1083
+ old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns, model_classes: Dict[str, List[str]]
1084
+ ):
1085
+ """
1086
+ Add a model to the relevant mappings in the auto module.
1087
+
1088
+ Args:
1089
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
1090
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
1091
+ model_classes (`Dict[str, List[str]]`): A dictionary mapping each framework to the list of model classes implemented for it.
1092
+ """
1093
+ for filename in AUTO_CLASSES_PATTERNS:
1094
+ # Extend patterns with all model classes if necessary
1095
+ new_patterns = []
1096
+ for pattern in AUTO_CLASSES_PATTERNS[filename]:
1097
+ if re.search("any_([a-z]*)_class", pattern) is not None:
1098
+ framework = re.search("any_([a-z]*)_class", pattern).groups()[0]
1099
+ if framework in model_classes:
1100
+ new_patterns.extend(
1101
+ [
1102
+ pattern.replace("{" + f"any_{framework}_class" + "}", cls)
1103
+ for cls in model_classes[framework]
1104
+ ]
1105
+ )
1106
+ elif "{config_class}" in pattern:
1107
+ new_patterns.append(pattern.replace("{config_class}", old_model_patterns.config_class))
1108
+ elif "{image_processor_class}" in pattern:
1109
+ if (
1110
+ old_model_patterns.image_processor_class is not None
1111
+ and new_model_patterns.image_processor_class is not None
1112
+ ):
1113
+ new_patterns.append(
1114
+ pattern.replace("{image_processor_class}", old_model_patterns.image_processor_class)
1115
+ )
1116
+ elif "{feature_extractor_class}" in pattern:
1117
+ if (
1118
+ old_model_patterns.feature_extractor_class is not None
1119
+ and new_model_patterns.feature_extractor_class is not None
1120
+ ):
1121
+ new_patterns.append(
1122
+ pattern.replace("{feature_extractor_class}", old_model_patterns.feature_extractor_class)
1123
+ )
1124
+ elif "{processor_class}" in pattern:
1125
+ if old_model_patterns.processor_class is not None and new_model_patterns.processor_class is not None:
1126
+ new_patterns.append(pattern.replace("{processor_class}", old_model_patterns.processor_class))
1127
+ else:
1128
+ new_patterns.append(pattern)
1129
+
1130
+ # Loop through all patterns.
1131
+ for pattern in new_patterns:
1132
+ full_name = TRANSFORMERS_PATH / "models" / "auto" / filename
1133
+ old_model_line = pattern
1134
+ new_model_line = pattern
1135
+ for attr in ["model_type", "model_name"]:
1136
+ old_model_line = old_model_line.replace("{" + attr + "}", getattr(old_model_patterns, attr))
1137
+ new_model_line = new_model_line.replace("{" + attr + "}", getattr(new_model_patterns, attr))
1138
+ if "pretrained_archive_map" in pattern:
1139
+ old_model_line = old_model_line.replace(
1140
+ "{pretrained_archive_map}", f"{old_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP"
1141
+ )
1142
+ new_model_line = new_model_line.replace(
1143
+ "{pretrained_archive_map}", f"{new_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP"
1144
+ )
1145
+
1146
+ new_model_line = new_model_line.replace(
1147
+ old_model_patterns.model_camel_cased, new_model_patterns.model_camel_cased
1148
+ )
1149
+
1150
+ add_content_to_file(full_name, new_model_line, add_after=old_model_line)
1151
+
1152
+ # Tokenizers require special handling
1153
+ insert_tokenizer_in_auto_module(old_model_patterns, new_model_patterns)
1154
+
1155
+
1156
+ DOC_OVERVIEW_TEMPLATE = """## Overview
1157
+
1158
+ The {model_name} model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
1159
+ <INSERT SHORT SUMMARY HERE>
1160
+
1161
+ The abstract from the paper is the following:
1162
+
1163
+ *<INSERT PAPER ABSTRACT HERE>*
1164
+
1165
+ Tips:
1166
+
1167
+ <INSERT TIPS ABOUT MODEL HERE>
1168
+
1169
+ This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
1170
+ The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
1171
+
1172
+ """
1173
+
1174
+
1175
+ def duplicate_doc_file(
1176
+ doc_file: Union[str, os.PathLike],
1177
+ old_model_patterns: ModelPatterns,
1178
+ new_model_patterns: ModelPatterns,
1179
+ dest_file: Optional[Union[str, os.PathLike]] = None,
1180
+ frameworks: Optional[List[str]] = None,
1181
+ ):
1182
+ """
1183
+ Duplicates a documentation file and adapts it for a new model.
1184
+
1185
+ Args:
1186
+ doc_file (`str` or `os.PathLike`): Path to the doc file to duplicate.
1187
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
1188
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
1189
+ dest_file (`str` or `os.PathLike`, *optional*): Path to the new doc file.
1190
+ Will default to a file named `{new_model_patterns.model_type}.md` in the same folder as `doc_file`.
1191
+ frameworks (`List[str]`, *optional*):
1192
+ If passed, will only keep the model classes corresponding to this list of frameworks in the new doc file.
1193
+ """
1194
+ with open(doc_file, "r", encoding="utf-8") as f:
1195
+ content = f.read()
1196
+
1197
+ content = re.sub(r"<!--\s*Copyright (\d+)\s", f"<!--Copyright {CURRENT_YEAR} ", content)
1198
+ if frameworks is None:
1199
+ frameworks = get_default_frameworks()
1200
+ if dest_file is None:
1201
+ dest_file = Path(doc_file).parent / f"{new_model_patterns.model_type}.md"
1202
+
1203
+ # Parse the doc file in blocks. One block per section/header
1204
+ lines = content.split("\n")
1205
+ blocks = []
1206
+ current_block = []
1207
+
1208
+ for line in lines:
1209
+ if line.startswith("#"):
1210
+ blocks.append("\n".join(current_block))
1211
+ current_block = [line]
1212
+ else:
1213
+ current_block.append(line)
1214
+ blocks.append("\n".join(current_block))
1215
+
1216
+ new_blocks = []
1217
+ in_classes = False
1218
+ for block in blocks:
1219
+ # Copyright
1220
+ if not block.startswith("#"):
1221
+ new_blocks.append(block)
1222
+ # Main title
1223
+ elif re.search(r"^#\s+\S+", block) is not None:
1224
+ new_blocks.append(f"# {new_model_patterns.model_name}\n")
1225
+ # The config starts the part of the doc with the classes.
1226
+ elif not in_classes and old_model_patterns.config_class in block.split("\n")[0]:
1227
+ in_classes = True
1228
+ new_blocks.append(DOC_OVERVIEW_TEMPLATE.format(model_name=new_model_patterns.model_name))
1229
+ new_block, _ = replace_model_patterns(block, old_model_patterns, new_model_patterns)
1230
+ new_blocks.append(new_block)
1231
+ # In classes
1232
+ elif in_classes:
1233
+ in_classes = True
1234
+ block_title = block.split("\n")[0]
1235
+ block_class = re.search(r"^#+\s+(\S.*)$", block_title).groups()[0]
1236
+ new_block, _ = replace_model_patterns(block, old_model_patterns, new_model_patterns)
1237
+
1238
+ if "Tokenizer" in block_class:
1239
+ # We only add the tokenizer if necessary
1240
+ if old_model_patterns.tokenizer_class != new_model_patterns.tokenizer_class:
1241
+ new_blocks.append(new_block)
1242
+ elif "ImageProcessor" in block_class:
1243
+ # We only add the image processor if necessary
1244
+ if old_model_patterns.image_processor_class != new_model_patterns.image_processor_class:
1245
+ new_blocks.append(new_block)
1246
+ elif "FeatureExtractor" in block_class:
1247
+ # We only add the feature extractor if necessary
1248
+ if old_model_patterns.feature_extractor_class != new_model_patterns.feature_extractor_class:
1249
+ new_blocks.append(new_block)
1250
+ elif "Processor" in block_class:
1251
+ # We only add the processor if necessary
1252
+ if old_model_patterns.processor_class != new_model_patterns.processor_class:
1253
+ new_blocks.append(new_block)
1254
+ elif block_class.startswith("Flax"):
1255
+ # We only add Flax models if in the selected frameworks
1256
+ if "flax" in frameworks:
1257
+ new_blocks.append(new_block)
1258
+ elif block_class.startswith("TF"):
1259
+ # We only add TF models if in the selected frameworks
1260
+ if "tf" in frameworks:
1261
+ new_blocks.append(new_block)
1262
+ elif len(block_class.split(" ")) == 1:
1263
+ # We only add PyTorch models if in the selected frameworks
1264
+ if "pt" in frameworks:
1265
+ new_blocks.append(new_block)
1266
+ else:
1267
+ new_blocks.append(new_block)
1268
+
1269
+ with open(dest_file, "w", encoding="utf-8") as f:
1270
+ f.write("\n".join(new_blocks))
1271
+
1272
+
1273
+ def insert_model_in_doc_toc(old_model_patterns, new_model_patterns):
1274
+ """
1275
+ Insert the new model in the doc TOC, in the same section as the old model.
1276
+
1277
+ Args:
1278
+ old_model_patterns (`ModelPatterns`): The patterns for the old model.
1279
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
1280
+ """
1281
+ toc_file = REPO_PATH / "docs" / "source" / "en" / "_toctree.yml"
1282
+ with open(toc_file, "r", encoding="utf8") as f:
1283
+ content = yaml.safe_load(f)
1284
+
1285
+ # Get to the model API doc
1286
+ api_idx = 0
1287
+ while content[api_idx]["title"] != "API":
1288
+ api_idx += 1
1289
+ api_doc = content[api_idx]["sections"]
1290
+
1291
+ model_idx = 0
1292
+ while api_doc[model_idx]["title"] != "Models":
1293
+ model_idx += 1
1294
+ model_doc = api_doc[model_idx]["sections"]
1295
+
1296
+ # Find the base model in the Toc
1297
+ old_model_type = old_model_patterns.model_type
1298
+ section_idx = 0
1299
+ while section_idx < len(model_doc):
1300
+ sections = [entry["local"] for entry in model_doc[section_idx]["sections"]]
1301
+ if f"model_doc/{old_model_type}" in sections:
1302
+ break
1303
+
1304
+ section_idx += 1
1305
+
1306
+ if section_idx == len(model_doc):
1307
+ old_model = old_model_patterns.model_name
1308
+ new_model = new_model_patterns.model_name
1309
+ print(f"Did not find {old_model} in the table of content, so you will need to add {new_model} manually.")
1310
+ return
1311
+
1312
+ # Add the new model in the same toc
1313
+ toc_entry = {"local": f"model_doc/{new_model_patterns.model_type}", "title": new_model_patterns.model_name}
1314
+ model_doc[section_idx]["sections"].append(toc_entry)
1315
+ model_doc[section_idx]["sections"] = sorted(model_doc[section_idx]["sections"], key=lambda s: s["title"].lower())
1316
+ api_doc[model_idx]["sections"] = model_doc
1317
+ content[api_idx]["sections"] = api_doc
1318
+
1319
+ with open(toc_file, "w", encoding="utf-8") as f:
1320
+ f.write(yaml.dump(content, allow_unicode=True))
1321
+
1322
+
1323
+ def create_new_model_like(
1324
+ model_type: str,
1325
+ new_model_patterns: ModelPatterns,
1326
+ add_copied_from: bool = True,
1327
+ frameworks: Optional[List[str]] = None,
1328
+ old_checkpoint: Optional[str] = None,
1329
+ ):
1330
+ """
1331
+ Creates a new model module like a given model of the Transformers library.
1332
+
1333
+ Args:
1334
+ model_type (`str`): The model type to duplicate (like "bert" or "gpt2")
1335
+ new_model_patterns (`ModelPatterns`): The patterns for the new model.
1336
+ add_copied_from (`bool`, *optional*, defaults to `True`):
1337
+ Whether or not to add "Copied from" statements to all classes in the new model modeling files.
1338
+ frameworks (`List[str]`, *optional*):
1339
+ If passed, will limit the duplicate to the frameworks specified.
1340
+ old_checkpoint (`str`, *optional*):
1341
+ The name of the base checkpoint for the old model. Should be passed along when it can't be automatically
1342
+ recovered from the `model_type`.
1343
+ """
1344
+ # Retrieve all the old model info.
1345
+ model_info = retrieve_info_for_model(model_type, frameworks=frameworks)
1346
+ model_files = model_info["model_files"]
1347
+ old_model_patterns = model_info["model_patterns"]
1348
+ if old_checkpoint is not None:
1349
+ old_model_patterns.checkpoint = old_checkpoint
1350
+ if len(old_model_patterns.checkpoint) == 0:
1351
+ raise ValueError(
1352
+ "The old model checkpoint could not be recovered from the model type. Please pass it to the "
1353
+ "`old_checkpoint` argument."
1354
+ )
1355
+
1356
+ keep_old_processing = True
1357
+ for processing_attr in ["image_processor_class", "feature_extractor_class", "processor_class", "tokenizer_class"]:
1358
+ if getattr(old_model_patterns, processing_attr) != getattr(new_model_patterns, processing_attr):
1359
+ keep_old_processing = False
1360
+
1361
+ model_classes = model_info["model_classes"]
1362
+
1363
+ # 1. We create the module for our new model.
1364
+ old_module_name = model_files["module_name"]
1365
+ module_folder = TRANSFORMERS_PATH / "models" / new_model_patterns.model_lower_cased
1366
+ os.makedirs(module_folder, exist_ok=True)
1367
+
1368
+ files_to_adapt = model_files["model_files"]
1369
+ if keep_old_processing:
1370
+ files_to_adapt = [
1371
+ f
1372
+ for f in files_to_adapt
1373
+ if "tokenization" not in str(f)
1374
+ and "processing" not in str(f)
1375
+ and "feature_extraction" not in str(f)
1376
+ and "image_processing" not in str(f)
1377
+ ]
1378
+
1379
+ os.makedirs(module_folder, exist_ok=True)
1380
+ for module_file in files_to_adapt:
1381
+ new_module_name = module_file.name.replace(
1382
+ old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
1383
+ )
1384
+ dest_file = module_folder / new_module_name
1385
+ duplicate_module(
1386
+ module_file,
1387
+ old_model_patterns,
1388
+ new_model_patterns,
1389
+ dest_file=dest_file,
1390
+ add_copied_from=add_copied_from and "modeling" in new_module_name,
1391
+ )
1392
+
1393
+ clean_frameworks_in_init(
1394
+ module_folder / "__init__.py", frameworks=frameworks, keep_processing=not keep_old_processing
1395
+ )
1396
+
1397
+ # 2. We add our new model to the models init and the main init
1398
+ add_content_to_file(
1399
+ TRANSFORMERS_PATH / "models" / "__init__.py",
1400
+ f" {new_model_patterns.model_lower_cased},",
1401
+ add_after=f" {old_module_name},",
1402
+ exact_match=True,
1403
+ )
1404
+ add_model_to_main_init(
1405
+ old_model_patterns, new_model_patterns, frameworks=frameworks, with_processing=not keep_old_processing
1406
+ )
1407
+
1408
+ # 3. Add test files
1409
+ files_to_adapt = model_files["test_files"]
1410
+ if keep_old_processing:
1411
+ files_to_adapt = [
1412
+ f
1413
+ for f in files_to_adapt
1414
+ if "tokenization" not in str(f)
1415
+ and "processor" not in str(f)
1416
+ and "feature_extraction" not in str(f)
1417
+ and "image_processing" not in str(f)
1418
+ ]
1419
+
1420
+ def disable_fx_test(filename: Path) -> bool:
1421
+ with open(filename) as fp:
1422
+ content = fp.read()
1423
+ new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content)
1424
+ with open(filename, "w") as fp:
1425
+ fp.write(new_content)
1426
+ return content != new_content
1427
+
1428
+ disabled_fx_test = False
1429
+
1430
+ tests_folder = REPO_PATH / "tests" / "models" / new_model_patterns.model_lower_cased
1431
+ os.makedirs(tests_folder, exist_ok=True)
1432
+ with open(tests_folder / "__init__.py", "w"):
1433
+ pass
1434
+
1435
+ for test_file in files_to_adapt:
1436
+ new_test_file_name = test_file.name.replace(
1437
+ old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
1438
+ )
1439
+ dest_file = test_file.parent.parent / new_model_patterns.model_lower_cased / new_test_file_name
1440
+ duplicate_module(
1441
+ test_file,
1442
+ old_model_patterns,
1443
+ new_model_patterns,
1444
+ dest_file=dest_file,
1445
+ add_copied_from=False,
1446
+ attrs_to_remove=["pipeline_model_mapping", "is_pipeline_test_to_skip"],
1447
+ )
1448
+ disabled_fx_test = disabled_fx_test | disable_fx_test(dest_file)
1449
+
1450
+ if disabled_fx_test:
1451
+ print(
1452
+ "The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works"
1453
+ " for your new model."
1454
+ )
1455
+
1456
+ # 4. Add model to auto classes
1457
+ add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes)
1458
+
1459
+ # 5. Add doc file
1460
+ doc_file = REPO_PATH / "docs" / "source" / "en" / "model_doc" / f"{old_model_patterns.model_type}.md"
1461
+ duplicate_doc_file(doc_file, old_model_patterns, new_model_patterns, frameworks=frameworks)
1462
+ insert_model_in_doc_toc(old_model_patterns, new_model_patterns)
1463
+
1464
+ # 6. Warn the user for duplicate patterns
1465
+ if old_model_patterns.model_type == old_model_patterns.checkpoint:
1466
+ print(
1467
+ "The model you picked has the same name for the model type and the checkpoint name "
1468
+ f"({old_model_patterns.model_type}). As a result, it's possible some places where the new checkpoint "
1469
+ f"should be, you have {new_model_patterns.model_type} instead. You should search for all instances of "
1470
+ f"{new_model_patterns.model_type} in the new files and check they're not badly used as checkpoints."
1471
+ )
1472
+ elif old_model_patterns.model_lower_cased == old_model_patterns.checkpoint:
1473
+ print(
1474
+ "The model you picked has the same name for the model type and the checkpoint name "
1475
+ f"({old_model_patterns.model_lower_cased}). As a result, it's possible some places where the new "
1476
+ f"checkpoint should be, you have {new_model_patterns.model_lower_cased} instead. You should search for "
1477
+ f"all instances of {new_model_patterns.model_lower_cased} in the new files and check they're not badly "
1478
+ "used as checkpoints."
1479
+ )
1480
+ if (
1481
+ old_model_patterns.model_type == old_model_patterns.model_lower_cased
1482
+ and new_model_patterns.model_type != new_model_patterns.model_lower_cased
1483
+ ):
1484
+ print(
1485
+ "The model you picked has the same name for the model type and the lowercased model name "
1486
+ f"({old_model_patterns.model_lower_cased}). As a result, it's possible some places where the new "
1487
+ f"model type should be, you have {new_model_patterns.model_lower_cased} instead. You should search for "
1488
+ f"all instances of {new_model_patterns.model_lower_cased} in the new files and check they're not badly "
1489
+ "used as the model type."
1490
+ )
1491
+
1492
+ if not keep_old_processing and old_model_patterns.tokenizer_class is not None:
1493
+ print(
1494
+ "The constants at the start of the new tokenizer file created needs to be manually fixed. If your new "
1495
+ "model has a tokenizer fast, you will also need to manually add the converter in the "
1496
+ "`SLOW_TO_FAST_CONVERTERS` constant of `convert_slow_tokenizer.py`."
1497
+ )
1498
+
1499
+
1500
+ def add_new_model_like_command_factory(args: Namespace):
1501
+ return AddNewModelLikeCommand(config_file=args.config_file, path_to_repo=args.path_to_repo)
1502
+
1503
+
1504
+ class AddNewModelLikeCommand(BaseTransformersCLICommand):
1505
+ @staticmethod
1506
+ def register_subcommand(parser: ArgumentParser):
1507
+ add_new_model_like_parser = parser.add_parser("add-new-model-like")
1508
+ add_new_model_like_parser.add_argument(
1509
+ "--config_file", type=str, help="A file with all the information for this model creation."
1510
+ )
1511
+ add_new_model_like_parser.add_argument(
1512
+ "--path_to_repo", type=str, help="When not using an editable install, the path to the Transformers repo."
1513
+ )
1514
+ add_new_model_like_parser.set_defaults(func=add_new_model_like_command_factory)
1515
+
1516
+ def __init__(self, config_file=None, path_to_repo=None, *args):
1517
+ if config_file is not None:
1518
+ with open(config_file, "r", encoding="utf-8") as f:
1519
+ config = json.load(f)
1520
+ self.old_model_type = config["old_model_type"]
1521
+ self.model_patterns = ModelPatterns(**config["new_model_patterns"])
1522
+ self.add_copied_from = config.get("add_copied_from", True)
1523
+ self.frameworks = config.get("frameworks", get_default_frameworks())
1524
+ self.old_checkpoint = config.get("old_checkpoint", None)
1525
+ else:
1526
+ (
1527
+ self.old_model_type,
1528
+ self.model_patterns,
1529
+ self.add_copied_from,
1530
+ self.frameworks,
1531
+ self.old_checkpoint,
1532
+ ) = get_user_input()
1533
+
1534
+ self.path_to_repo = path_to_repo
1535
+
1536
+ def run(self):
1537
+ if self.path_to_repo is not None:
1538
+ # Adapt constants
1539
+ global TRANSFORMERS_PATH
1540
+ global REPO_PATH
1541
+
1542
+ REPO_PATH = Path(self.path_to_repo)
1543
+ TRANSFORMERS_PATH = REPO_PATH / "src" / "transformers"
1544
+
1545
+ create_new_model_like(
1546
+ model_type=self.old_model_type,
1547
+ new_model_patterns=self.model_patterns,
1548
+ add_copied_from=self.add_copied_from,
1549
+ frameworks=self.frameworks,
1550
+ old_checkpoint=self.old_checkpoint,
1551
+ )
1552
+
1553
+
1554
+ def get_user_field(
1555
+ question: str,
1556
+ default_value: Optional[str] = None,
1557
+ is_valid_answer: Optional[Callable] = None,
1558
+ convert_to: Optional[Callable] = None,
1559
+ fallback_message: Optional[str] = None,
1560
+ ) -> Any:
1561
+ """
1562
+ A utility function that asks a question to the user to get an answer, potentially looping until it gets a valid
1563
+ answer.
1564
+
1565
+ Args:
1566
+ question (`str`): The question to ask the user.
1567
+ default_value (`str`, *optional*): A potential default value that will be used when the answer is empty.
1568
+ is_valid_answer (`Callable`, *optional*):
1569
+ If set, the question will be asked until this function returns `True` on the provided answer.
1570
+ convert_to (`Callable`, *optional*):
1571
+ If set, the answer will be passed to this function. If this function raises an error on the provided
1572
+ answer, the question will be asked again.
1573
+ fallback_message (`str`, *optional*):
1574
+ A message that will be displayed each time the question is asked again to the user.
1575
+
1576
+ Returns:
1577
+ `Any`: The answer provided by the user (or the default), passed through the potential conversion function.
1578
+ """
1579
+ if not question.endswith(" "):
1580
+ question = question + " "
1581
+ if default_value is not None:
1582
+ question = f"{question} [{default_value}] "
1583
+
1584
+ valid_answer = False
1585
+ while not valid_answer:
1586
+ answer = input(question)
1587
+ if default_value is not None and len(answer) == 0:
1588
+ answer = default_value
1589
+ if is_valid_answer is not None:
1590
+ valid_answer = is_valid_answer(answer)
1591
+ elif convert_to is not None:
1592
+ try:
1593
+ answer = convert_to(answer)
1594
+ valid_answer = True
1595
+ except Exception:
1596
+ valid_answer = False
1597
+ else:
1598
+ valid_answer = True
1599
+
1600
+ if not valid_answer:
1601
+ print(fallback_message)
1602
+
1603
+ return answer
1604
+
1605
+
1606
+ def convert_to_bool(x: str) -> bool:
1607
+ """
1608
+ Converts a string to a bool.
1609
+ """
1610
+ if x.lower() in ["1", "y", "yes", "true"]:
1611
+ return True
1612
+ if x.lower() in ["0", "n", "no", "false"]:
1613
+ return False
1614
+ raise ValueError(f"{x} is not a value that can be converted to a bool.")
1615
+
1616
+
1617
+ def get_user_input():
1618
+ """
1619
+ Ask the user for the necessary inputs to add the new model.
1620
+ """
1621
+ model_types = list(auto_module.configuration_auto.MODEL_NAMES_MAPPING.keys())
1622
+
1623
+ # Get old model type
1624
+ valid_model_type = False
1625
+ while not valid_model_type:
1626
+ old_model_type = input(
1627
+ "What is the model you would like to duplicate? Please provide the lowercase `model_type` (e.g. roberta): "
1628
+ )
1629
+ if old_model_type in model_types:
1630
+ valid_model_type = True
1631
+ else:
1632
+ print(f"{old_model_type} is not a valid model type.")
1633
+ near_choices = difflib.get_close_matches(old_model_type, model_types)
1634
+ if len(near_choices) >= 1:
1635
+ if len(near_choices) > 1:
1636
+ near_choices = " or ".join(near_choices)
1637
+ print(f"Did you mean {near_choices}?")
1638
+
1639
+ old_model_info = retrieve_info_for_model(old_model_type)
1640
+ old_tokenizer_class = old_model_info["model_patterns"].tokenizer_class
1641
+ old_image_processor_class = old_model_info["model_patterns"].image_processor_class
1642
+ old_feature_extractor_class = old_model_info["model_patterns"].feature_extractor_class
1643
+ old_processor_class = old_model_info["model_patterns"].processor_class
1644
+ old_frameworks = old_model_info["frameworks"]
1645
+
1646
+ old_checkpoint = None
1647
+ if len(old_model_info["model_patterns"].checkpoint) == 0:
1648
+ old_checkpoint = get_user_field(
1649
+ "We couldn't find the name of the base checkpoint for that model, please enter it here."
1650
+ )
1651
+
1652
+ model_name = get_user_field(
1653
+ "What is the name (with no special casing) for your new model in the paper (e.g. RoBERTa)? "
1654
+ )
1655
+ default_patterns = ModelPatterns(model_name, model_name)
1656
+
1657
+ model_type = get_user_field(
1658
+ "What identifier would you like to use for the `model_type` of this model? ",
1659
+ default_value=default_patterns.model_type,
1660
+ )
1661
+ model_lower_cased = get_user_field(
1662
+ "What lowercase name would you like to use for the module (folder) of this model? ",
1663
+ default_value=default_patterns.model_lower_cased,
1664
+ )
1665
+ model_camel_cased = get_user_field(
1666
+ "What prefix (camel-cased) would you like to use for the model classes of this model (e.g. Roberta)? ",
1667
+ default_value=default_patterns.model_camel_cased,
1668
+ )
1669
+ model_upper_cased = get_user_field(
1670
+ "What prefix (upper-cased) would you like to use for the constants relative to this model? ",
1671
+ default_value=default_patterns.model_upper_cased,
1672
+ )
1673
+ config_class = get_user_field(
1674
+ "What will be the name of the config class for this model? ", default_value=f"{model_camel_cased}Config"
1675
+ )
1676
+ checkpoint = get_user_field(
1677
+ "Please give a checkpoint identifier (on the model Hub) for this new model (e.g. facebook/roberta-base): "
1678
+ )
1679
+
1680
+ old_processing_classes = [
1681
+ c
1682
+ for c in [old_image_processor_class, old_feature_extractor_class, old_tokenizer_class, old_processor_class]
1683
+ if c is not None
1684
+ ]
1685
+ old_processing_classes = ", ".join(old_processing_classes)
1686
+ keep_processing = get_user_field(
1687
+ f"Will your new model use the same processing class as {old_model_type} ({old_processing_classes}) (yes/no)? ",
1688
+ convert_to=convert_to_bool,
1689
+ fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
1690
+ )
1691
+ if keep_processing:
1692
+ image_processor_class = old_image_processor_class
1693
+ feature_extractor_class = old_feature_extractor_class
1694
+ processor_class = old_processor_class
1695
+ tokenizer_class = old_tokenizer_class
1696
+ else:
1697
+ if old_tokenizer_class is not None:
1698
+ tokenizer_class = get_user_field(
1699
+ "What will be the name of the tokenizer class for this model? ",
1700
+ default_value=f"{model_camel_cased}Tokenizer",
1701
+ )
1702
+ else:
1703
+ tokenizer_class = None
1704
+ if old_image_processor_class is not None:
1705
+ image_processor_class = get_user_field(
1706
+ "What will be the name of the image processor class for this model? ",
1707
+ default_value=f"{model_camel_cased}ImageProcessor",
1708
+ )
1709
+ else:
1710
+ image_processor_class = None
1711
+ if old_feature_extractor_class is not None:
1712
+ feature_extractor_class = get_user_field(
1713
+ "What will be the name of the feature extractor class for this model? ",
1714
+ default_value=f"{model_camel_cased}FeatureExtractor",
1715
+ )
1716
+ else:
1717
+ feature_extractor_class = None
1718
+ if old_processor_class is not None:
1719
+ processor_class = get_user_field(
1720
+ "What will be the name of the processor class for this model? ",
1721
+ default_value=f"{model_camel_cased}Processor",
1722
+ )
1723
+ else:
1724
+ processor_class = None
1725
+
1726
+ model_patterns = ModelPatterns(
1727
+ model_name,
1728
+ checkpoint,
1729
+ model_type=model_type,
1730
+ model_lower_cased=model_lower_cased,
1731
+ model_camel_cased=model_camel_cased,
1732
+ model_upper_cased=model_upper_cased,
1733
+ config_class=config_class,
1734
+ tokenizer_class=tokenizer_class,
1735
+ image_processor_class=image_processor_class,
1736
+ feature_extractor_class=feature_extractor_class,
1737
+ processor_class=processor_class,
1738
+ )
1739
+
1740
+ add_copied_from = get_user_field(
1741
+ "Should we add # Copied from statements when creating the new modeling file (yes/no)? ",
1742
+ convert_to=convert_to_bool,
1743
+ default_value="yes",
1744
+ fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
1745
+ )
1746
+
1747
+ all_frameworks = get_user_field(
1748
+ "Should we add a version of your new model in all the frameworks implemented by"
1749
+ f" {old_model_type} ({old_frameworks}) (yes/no)? ",
1750
+ convert_to=convert_to_bool,
1751
+ default_value="yes",
1752
+ fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
1753
+ )
1754
+ if all_frameworks:
1755
+ frameworks = None
1756
+ else:
1757
+ frameworks = get_user_field(
1758
+ "Please enter the list of framworks you want (pt, tf, flax) separated by spaces",
1759
+ is_valid_answer=lambda x: all(p in ["pt", "tf", "flax"] for p in x.split(" ")),
1760
+ )
1761
+ frameworks = list(set(frameworks.split(" ")))
1762
+
1763
+ return (old_model_type, model_patterns, add_copied_from, frameworks, old_checkpoint)
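End of add_new_model_like.py. As an editor's sketch (not part of the commit), the non-interactive path accepts a JSON config whose keys mirror the ones read in AddNewModelLikeCommand.__init__ above; a file like the following, written from Python, could then be passed to `transformers-cli add-new-model-like --config_file add_new_model_config.json`. All model names and identifiers below are hypothetical.

import json

# Hypothetical config mirroring the keys read in AddNewModelLikeCommand.__init__.
config = {
    "old_model_type": "bert",
    "new_model_patterns": {
        "model_name": "New BERT",              # hypothetical new model name
        "checkpoint": "my-org/new-bert-base",  # hypothetical Hub checkpoint
        "model_type": "new-bert",
        "model_lower_cased": "new_bert",
        "model_camel_cased": "NewBert",
        "model_upper_cased": "NEW_BERT",
        "config_class": "NewBertConfig",
        "tokenizer_class": "BertTokenizer",    # reuse the old processing class
    },
    "add_copied_from": True,
    "frameworks": ["pt"],
    "old_checkpoint": None,
}

with open("add_new_model_config.json", "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)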
transformers_4_35_0/commands/convert.py ADDED
@@ -0,0 +1,184 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from argparse import ArgumentParser, Namespace
16
+
17
+ from ..utils import logging
18
+ from . import BaseTransformersCLICommand
19
+
20
+
21
+ def convert_command_factory(args: Namespace):
22
+ """
23
+ Factory function used to convert a TF 1.0 model checkpoint into a PyTorch checkpoint.
24
+
25
+ Returns: ConvertCommand
26
+ """
27
+ return ConvertCommand(
28
+ args.model_type, args.tf_checkpoint, args.pytorch_dump_output, args.config, args.finetuning_task_name
29
+ )
30
+
31
+
32
+ IMPORT_ERROR_MESSAGE = """
33
+ transformers can only be used from the command line to convert TensorFlow models to PyTorch. In that case, it requires
34
+ TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.
35
+ """
36
+
37
+
38
+ class ConvertCommand(BaseTransformersCLICommand):
39
+ @staticmethod
40
+ def register_subcommand(parser: ArgumentParser):
41
+ """
42
+ Register this command to argparse so it's available for the transformers-cli
43
+
44
+ Args:
45
+ parser: Root parser to register command-specific arguments
46
+ """
47
+ train_parser = parser.add_parser(
48
+ "convert",
49
+ help="CLI tool to run convert model from original author checkpoints to Transformers PyTorch checkpoints.",
50
+ )
51
+ train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
52
+ train_parser.add_argument(
53
+ "--tf_checkpoint", type=str, required=True, help="TensorFlow checkpoint path or folder."
54
+ )
55
+ train_parser.add_argument(
56
+ "--pytorch_dump_output", type=str, required=True, help="Path to the PyTorch saved model output."
57
+ )
58
+ train_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.")
59
+ train_parser.add_argument(
60
+ "--finetuning_task_name",
61
+ type=str,
62
+ default=None,
63
+ help="Optional fine-tuning task name if the TF model was a finetuned model.",
64
+ )
65
+ train_parser.set_defaults(func=convert_command_factory)
66
+
67
+ def __init__(
68
+ self,
69
+ model_type: str,
70
+ tf_checkpoint: str,
71
+ pytorch_dump_output: str,
72
+ config: str,
73
+ finetuning_task_name: str,
74
+ *args,
75
+ ):
76
+ self._logger = logging.get_logger("transformers-cli/converting")
77
+
78
+ self._logger.info(f"Loading model {model_type}")
79
+ self._model_type = model_type
80
+ self._tf_checkpoint = tf_checkpoint
81
+ self._pytorch_dump_output = pytorch_dump_output
82
+ self._config = config
83
+ self._finetuning_task_name = finetuning_task_name
84
+
85
+ def run(self):
86
+ if self._model_type == "albert":
87
+ try:
88
+ from ..models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
89
+ convert_tf_checkpoint_to_pytorch,
90
+ )
91
+ except ImportError:
92
+ raise ImportError(IMPORT_ERROR_MESSAGE)
93
+
94
+ convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
95
+ elif self._model_type == "bert":
96
+ try:
97
+ from ..models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
98
+ convert_tf_checkpoint_to_pytorch,
99
+ )
100
+ except ImportError:
101
+ raise ImportError(IMPORT_ERROR_MESSAGE)
102
+
103
+ convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
104
+ elif self._model_type == "funnel":
105
+ try:
106
+ from ..models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
107
+ convert_tf_checkpoint_to_pytorch,
108
+ )
109
+ except ImportError:
110
+ raise ImportError(IMPORT_ERROR_MESSAGE)
111
+
112
+ convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
113
+ elif self._model_type == "t5":
114
+ try:
115
+ from ..models.t5.convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
116
+ except ImportError:
117
+ raise ImportError(IMPORT_ERROR_MESSAGE)
118
+
119
+ convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
120
+ elif self._model_type == "gpt":
121
+ from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
122
+ convert_openai_checkpoint_to_pytorch,
123
+ )
124
+
125
+ convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
126
+ elif self._model_type == "transfo_xl":
127
+ try:
128
+ from ..models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
129
+ convert_transfo_xl_checkpoint_to_pytorch,
130
+ )
131
+ except ImportError:
132
+ raise ImportError(IMPORT_ERROR_MESSAGE)
133
+
134
+ if "ckpt" in self._tf_checkpoint.lower():
135
+ TF_CHECKPOINT = self._tf_checkpoint
136
+ TF_DATASET_FILE = ""
137
+ else:
138
+ TF_DATASET_FILE = self._tf_checkpoint
139
+ TF_CHECKPOINT = ""
140
+ convert_transfo_xl_checkpoint_to_pytorch(
141
+ TF_CHECKPOINT, self._config, self._pytorch_dump_output, TF_DATASET_FILE
142
+ )
143
+ elif self._model_type == "gpt2":
144
+ try:
145
+ from ..models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
146
+ convert_gpt2_checkpoint_to_pytorch,
147
+ )
148
+ except ImportError:
149
+ raise ImportError(IMPORT_ERROR_MESSAGE)
150
+
151
+ convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
152
+ elif self._model_type == "xlnet":
153
+ try:
154
+ from ..models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
155
+ convert_xlnet_checkpoint_to_pytorch,
156
+ )
157
+ except ImportError:
158
+ raise ImportError(IMPORT_ERROR_MESSAGE)
159
+
160
+ convert_xlnet_checkpoint_to_pytorch(
161
+ self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
162
+ )
163
+ elif self._model_type == "xlm":
164
+ from ..models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
165
+ convert_xlm_checkpoint_to_pytorch,
166
+ )
167
+
168
+ convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
169
+ elif self._model_type == "lxmert":
170
+ from ..models.lxmert.convert_lxmert_original_tf_checkpoint_to_pytorch import (
171
+ convert_lxmert_checkpoint_to_pytorch,
172
+ )
173
+
174
+ convert_lxmert_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
175
+ elif self._model_type == "rembert":
176
+ from ..models.rembert.convert_rembert_tf_checkpoint_to_pytorch import (
177
+ convert_rembert_tf_checkpoint_to_pytorch,
178
+ )
179
+
180
+ convert_rembert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
181
+ else:
182
+ raise ValueError(
183
+ "--model_type should be selected in the list [bert, gpt, gpt2, t5, transfo_xl, xlnet, xlm, lxmert]"
184
+ )
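End of convert.py. A hedged usage sketch (editor's addition): invoking the converter from Python via the CLI flags defined in register_subcommand above; all paths are hypothetical.

import subprocess

# Hypothetical paths; the flags match the argparse definitions in ConvertCommand above.
subprocess.run(
    [
        "transformers-cli", "convert",
        "--model_type", "bert",
        "--tf_checkpoint", "/path/to/bert_model.ckpt",          # hypothetical TF 1.0 checkpoint
        "--config", "/path/to/bert_config.json",                # hypothetical config file
        "--pytorch_dump_output", "/path/to/pytorch_model.bin",  # hypothetical output path
    ],
    check=True,
)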
transformers_4_35_0/commands/download.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from argparse import ArgumentParser
16
+
17
+ from . import BaseTransformersCLICommand
18
+
19
+
20
+ def download_command_factory(args):
21
+ return DownloadCommand(args.model, args.cache_dir, args.force, args.trust_remote_code)
22
+
23
+
24
+ class DownloadCommand(BaseTransformersCLICommand):
25
+ @staticmethod
26
+ def register_subcommand(parser: ArgumentParser):
27
+ download_parser = parser.add_parser("download")
28
+ download_parser.add_argument(
29
+ "--cache-dir", type=str, default=None, help="Path to location to store the models"
30
+ )
31
+ download_parser.add_argument(
32
+ "--force", action="store_true", help="Force the model to be download even if already in cache-dir"
33
+ )
34
+ download_parser.add_argument(
35
+ "--trust-remote-code",
36
+ action="store_true",
37
+ help="Whether or not to allow for custom models defined on the Hub in their own modeling files. Use only if you've reviewed the code as it will execute on your local machine",
38
+ )
39
+ download_parser.add_argument("model", type=str, help="Name of the model to download")
40
+ download_parser.set_defaults(func=download_command_factory)
41
+
42
+ def __init__(self, model: str, cache: str, force: bool, trust_remote_code: bool):
43
+ self._model = model
44
+ self._cache = cache
45
+ self._force = force
46
+ self._trust_remote_code = trust_remote_code
47
+
48
+ def run(self):
49
+ from ..models.auto import AutoModel, AutoTokenizer
50
+
51
+ AutoModel.from_pretrained(
52
+ self._model, cache_dir=self._cache, force_download=self._force, trust_remote_code=self._trust_remote_code
53
+ )
54
+ AutoTokenizer.from_pretrained(
55
+ self._model, cache_dir=self._cache, force_download=self._force, trust_remote_code=self._trust_remote_code
56
+ )
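End of download.py. An editor's sketch of the equivalent direct call; the DownloadCommand above simply wraps the AutoModel/AutoTokenizer downloads. The model name and cache directory are assumptions, not taken from the commit.

# Hypothetical model name and cache directory; mirrors what `transformers-cli download` does.
from transformers import AutoModel, AutoTokenizer

model_name = "bert-base-uncased"   # assumption: any public Hub checkpoint works here
AutoModel.from_pretrained(model_name, cache_dir="./hf_cache", force_download=False)
AutoTokenizer.from_pretrained(model_name, cache_dir="./hf_cache", force_download=False)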
transformers_4_35_0/commands/env.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib.util
16
+ import os
17
+ import platform
18
+ from argparse import ArgumentParser
19
+
20
+ import huggingface_hub
21
+
22
+ from .. import __version__ as version
23
+ from ..utils import (
24
+ is_accelerate_available,
25
+ is_flax_available,
26
+ is_safetensors_available,
27
+ is_tf_available,
28
+ is_torch_available,
29
+ )
30
+ from . import BaseTransformersCLICommand
31
+
32
+
33
+ def info_command_factory(_):
34
+ return EnvironmentCommand()
35
+
36
+
37
+ def download_command_factory(args):
38
+ return EnvironmentCommand(args.accelerate_config_file)
39
+
40
+
41
+ class EnvironmentCommand(BaseTransformersCLICommand):
42
+ @staticmethod
43
+ def register_subcommand(parser: ArgumentParser):
44
+ download_parser = parser.add_parser("env")
45
+ download_parser.set_defaults(func=info_command_factory)
46
+ download_parser.add_argument(
47
+ "--accelerate-config_file",
48
+ default=None,
49
+ help="The accelerate config file to use for the default values in the launching script.",
50
+ )
51
+ download_parser.set_defaults(func=download_command_factory)
52
+
53
+ def __init__(self, accelerate_config_file, *args) -> None:
54
+ self._accelerate_config_file = accelerate_config_file
55
+
56
+ def run(self):
57
+ safetensors_version = "not installed"
58
+ if is_safetensors_available():
59
+ import safetensors
60
+
61
+ safetensors_version = safetensors.__version__
62
+ elif importlib.util.find_spec("safetensors") is not None:
63
+ import safetensors
64
+
65
+ safetensors_version = f"{safetensors.__version__} but is ignored because the installed PyTorch version is too old."
66
+
67
+ accelerate_version = "not installed"
68
+ accelerate_config = accelerate_config_str = "not found"
69
+ if is_accelerate_available():
70
+ import accelerate
71
+ from accelerate.commands.config import default_config_file, load_config_from_file
72
+
73
+ accelerate_version = accelerate.__version__
74
+ # Get the default from the config file.
75
+ if self._accelerate_config_file is not None or os.path.isfile(default_config_file):
76
+ accelerate_config = load_config_from_file(self._accelerate_config_file).to_dict()
77
+
78
+ accelerate_config_str = (
79
+ "\n".join([f"\t- {prop}: {val}" for prop, val in accelerate_config.items()])
80
+ if isinstance(accelerate_config, dict)
81
+ else f"\t{accelerate_config}"
82
+ )
83
+
84
+ pt_version = "not installed"
85
+ pt_cuda_available = "NA"
86
+ if is_torch_available():
87
+ import torch
88
+
89
+ pt_version = torch.__version__
90
+ pt_cuda_available = torch.cuda.is_available()
91
+
92
+ tf_version = "not installed"
93
+ tf_cuda_available = "NA"
94
+ if is_tf_available():
95
+ import tensorflow as tf
96
+
97
+ tf_version = tf.__version__
98
+ try:
99
+ # deprecated in v2.1
100
+ tf_cuda_available = tf.test.is_gpu_available()
101
+ except AttributeError:
102
+ # returns list of devices, convert to bool
103
+ tf_cuda_available = bool(tf.config.list_physical_devices("GPU"))
104
+
105
+ flax_version = "not installed"
106
+ jax_version = "not installed"
107
+ jaxlib_version = "not installed"
108
+ jax_backend = "NA"
109
+ if is_flax_available():
110
+ import flax
111
+ import jax
112
+ import jaxlib
113
+
114
+ flax_version = flax.__version__
115
+ jax_version = jax.__version__
116
+ jaxlib_version = jaxlib.__version__
117
+ jax_backend = jax.lib.xla_bridge.get_backend().platform
118
+
119
+ info = {
120
+ "`transformers` version": version,
121
+ "Platform": platform.platform(),
122
+ "Python version": platform.python_version(),
123
+ "Huggingface_hub version": huggingface_hub.__version__,
124
+ "Safetensors version": f"{safetensors_version}",
125
+ "Accelerate version": f"{accelerate_version}",
126
+ "Accelerate config": f"{accelerate_config_str}",
127
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
128
+ "Tensorflow version (GPU?)": f"{tf_version} ({tf_cuda_available})",
129
+ "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
130
+ "Jax version": f"{jax_version}",
131
+ "JaxLib version": f"{jaxlib_version}",
132
+ "Using GPU in script?": "<fill in>",
133
+ "Using distributed or parallel set-up in script?": "<fill in>",
134
+ }
135
+
136
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last two points.\n")
137
+ print(self.format_dict(info))
138
+
139
+ return info
140
+
141
+ @staticmethod
142
+ def format_dict(d):
143
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
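
`format_dict` is what turns the collected `info` dictionary into the bulleted report that `transformers-cli env` prints. An illustrative rendering (version and platform strings are made up):

```python
# Illustrative only: how EnvironmentCommand.format_dict renders the collected info.
info = {"`transformers` version": "4.35.0", "Platform": "Linux-5.15", "Python version": "3.10.12"}
report = "\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n"
print(report)
# - `transformers` version: 4.35.0
# - Platform: Linux-5.15
# - Python version: 3.10.12
```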
transformers_4_35_0/commands/lfs.py ADDED
@@ -0,0 +1,226 @@
1
+ """
2
+ Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs.
3
+
4
+ Inspired by: github.com/cbartz/git-lfs-swift-transfer-agent/blob/master/git_lfs_swift_transfer.py
5
+
6
+ Spec is: github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md
7
+
8
+
9
+ To launch debugger while developing:
10
+
11
+ ``` [lfs "customtransfer.multipart"]
+ path = /path/to/transformers/.env/bin/python
+ args = -m debugpy --listen 5678 --wait-for-client /path/to/transformers/src/transformers/commands/transformers_cli.py lfs-multipart-upload ```"""
14
+
15
+ import json
16
+ import os
17
+ import subprocess
18
+ import sys
19
+ import warnings
20
+ from argparse import ArgumentParser
21
+ from contextlib import AbstractContextManager
22
+ from typing import Dict, List, Optional
23
+
24
+ import requests
25
+
26
+ from ..utils import logging
27
+ from . import BaseTransformersCLICommand
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload"
34
+
35
+
36
+ class LfsCommands(BaseTransformersCLICommand):
37
+ """
38
+ Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs. This lets users upload
39
+ large files >5GB 🔥. Spec for LFS custom transfer agent is:
40
+ https://github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md
41
+
42
+ This introduces two commands to the CLI:
43
+
44
+ 1. $ transformers-cli lfs-enable-largefiles
45
+
46
+ This should be executed once for each model repo that contains a model file >5GB. It's documented in the error
47
+ message you get if you just try to git push a 5GB file without having enabled it before.
48
+
49
+ 2. $ transformers-cli lfs-multipart-upload
50
+
51
+ This command is called by lfs directly and is not meant to be called by the user.
52
+ """
53
+
54
+ @staticmethod
55
+ def register_subcommand(parser: ArgumentParser):
56
+ enable_parser = parser.add_parser(
57
+ "lfs-enable-largefiles",
58
+ help=(
59
+ "Deprecated: use `huggingface-cli` instead. Configure your repository to enable upload of files > 5GB."
60
+ ),
61
+ )
62
+ enable_parser.add_argument("path", type=str, help="Local path to repository you want to configure.")
63
+ enable_parser.set_defaults(func=lambda args: LfsEnableCommand(args))
64
+
65
+ upload_parser = parser.add_parser(
66
+ LFS_MULTIPART_UPLOAD_COMMAND,
67
+ help=(
68
+ "Deprecated: use `huggingface-cli` instead. "
69
+ "Command will get called by git-lfs, do not call it directly."
70
+ ),
71
+ )
72
+ upload_parser.set_defaults(func=lambda args: LfsUploadCommand(args))
73
+
74
+
75
+ class LfsEnableCommand:
76
+ def __init__(self, args):
77
+ self.args = args
78
+
79
+ def run(self):
80
+ warnings.warn(
81
+ "Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
82
+ )
83
+ local_path = os.path.abspath(self.args.path)
84
+ if not os.path.isdir(local_path):
85
+ print("This does not look like a valid git repo.")
86
+ exit(1)
87
+ subprocess.run(
88
+ "git config lfs.customtransfer.multipart.path transformers-cli".split(), check=True, cwd=local_path
89
+ )
90
+ subprocess.run(
91
+ f"git config lfs.customtransfer.multipart.args {LFS_MULTIPART_UPLOAD_COMMAND}".split(),
92
+ check=True,
93
+ cwd=local_path,
94
+ )
95
+ print("Local repo set up for largefiles")
96
+
97
+
98
+ def write_msg(msg: Dict):
99
+ """Write out the message in Line delimited JSON."""
100
+ msg = json.dumps(msg) + "\n"
101
+ sys.stdout.write(msg)
102
+ sys.stdout.flush()
103
+
104
+
105
+ def read_msg() -> Optional[Dict]:
106
+ """Read Line delimited JSON from stdin."""
107
+ msg = json.loads(sys.stdin.readline().strip())
108
+
109
+ if "terminate" in (msg.get("type"), msg.get("event")):
110
+ # terminate message received
111
+ return None
112
+
113
+ if msg.get("event") not in ("download", "upload"):
114
+ logger.critical("Received unexpected message")
115
+ sys.exit(1)
116
+
117
+ return msg
118
+
119
+
120
+ class FileSlice(AbstractContextManager):
121
+ """
122
+ File-like object that only reads a slice of a file
123
+
124
+ Inspired by stackoverflow.com/a/29838711/593036
125
+ """
126
+
127
+ def __init__(self, filepath: str, seek_from: int, read_limit: int):
128
+ self.filepath = filepath
129
+ self.seek_from = seek_from
130
+ self.read_limit = read_limit
131
+ self.n_seen = 0
132
+
133
+ def __enter__(self):
134
+ self.f = open(self.filepath, "rb")
135
+ self.f.seek(self.seek_from)
136
+ return self
137
+
138
+ def __len__(self):
139
+ total_length = os.fstat(self.f.fileno()).st_size
140
+ return min(self.read_limit, total_length - self.seek_from)
141
+
142
+ def read(self, n=-1):
143
+ if self.n_seen >= self.read_limit:
144
+ return b""
145
+ remaining_amount = self.read_limit - self.n_seen
146
+ data = self.f.read(remaining_amount if n < 0 else min(n, remaining_amount))
147
+ self.n_seen += len(data)
148
+ return data
149
+
150
+ def __iter__(self):
151
+ yield self.read(n=4 * 1024 * 1024)
152
+
153
+ def __exit__(self, *args):
154
+ self.f.close()
155
+
156
+
157
+ class LfsUploadCommand:
158
+ def __init__(self, args):
159
+ self.args = args
160
+
161
+ def run(self):
162
+ # Immediately after invoking a custom transfer process, git-lfs
163
+ # sends initiation data to the process over stdin.
164
+ # This tells the process useful information about the configuration.
165
+ init_msg = json.loads(sys.stdin.readline().strip())
166
+ if not (init_msg.get("event") == "init" and init_msg.get("operation") == "upload"):
167
+ write_msg({"error": {"code": 32, "message": "Wrong lfs init operation"}})
168
+ sys.exit(1)
169
+
170
+ # The transfer process should use the information it needs from the
171
+ # initiation structure, and also perform any one-off setup tasks it
172
+ # needs to do. It should then respond on stdout with a simple empty
173
+ # confirmation structure, as follows:
174
+ write_msg({})
175
+
176
+ # After the initiation exchange, git-lfs will send any number of
177
+ # transfer requests to the stdin of the transfer process, in a serial sequence.
178
+ while True:
179
+ msg = read_msg()
180
+ if msg is None:
181
+ # When all transfers have been processed, git-lfs will send
182
+ # a terminate event to the stdin of the transfer process.
183
+ # On receiving this message the transfer process should
184
+ # clean up and terminate. No response is expected.
185
+ sys.exit(0)
186
+
187
+ oid = msg["oid"]
188
+ filepath = msg["path"]
189
+ completion_url = msg["action"]["href"]
190
+ header = msg["action"]["header"]
191
+ chunk_size = int(header.pop("chunk_size"))
192
+ presigned_urls: List[str] = list(header.values())
193
+
194
+ parts = []
195
+ for i, presigned_url in enumerate(presigned_urls):
196
+ with FileSlice(filepath, seek_from=i * chunk_size, read_limit=chunk_size) as data:
197
+ r = requests.put(presigned_url, data=data)
198
+ r.raise_for_status()
199
+ parts.append(
200
+ {
201
+ "etag": r.headers.get("etag"),
202
+ "partNumber": i + 1,
203
+ }
204
+ )
205
+ # In order to support progress reporting while data is uploading / downloading,
206
+ # the transfer process should post messages to stdout
207
+ write_msg(
208
+ {
209
+ "event": "progress",
210
+ "oid": oid,
211
+ "bytesSoFar": (i + 1) * chunk_size,
212
+ "bytesSinceLast": chunk_size,
213
+ }
214
+ )
215
+ # Not precise but that's ok.
216
+
217
+ r = requests.post(
218
+ completion_url,
219
+ json={
220
+ "oid": oid,
221
+ "parts": parts,
222
+ },
223
+ )
224
+ r.raise_for_status()
225
+
226
+ write_msg({"event": "complete", "oid": oid})
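
As the module docstring notes, the agent speaks line-delimited JSON with git-lfs over stdin/stdout. A hedged sketch of the message shapes involved in one upload, mirroring the fields read and written above (oid, path, URLs and sizes are illustrative only):

```python
import json

# git-lfs -> agent: an init message, then one upload request per object (values illustrative).
init_msg = {"event": "init", "operation": "upload", "concurrent": True}
upload_msg = {
    "event": "upload",
    "oid": "deadbeef",           # illustrative object id
    "path": "/tmp/model.bin",    # illustrative local path
    "action": {
        "href": "https://example.com/complete",  # completion endpoint
        "header": {"chunk_size": "5242880", "00001": "https://example.com/part-1"},  # chunk size + presigned part URLs
    },
}

# agent -> git-lfs: progress updates and a completion event, one JSON object per line on stdout.
print(json.dumps({"event": "progress", "oid": "deadbeef", "bytesSoFar": 5242880, "bytesSinceLast": 5242880}))
print(json.dumps({"event": "complete", "oid": "deadbeef"}))
```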
transformers_4_35_0/commands/pt_to_tf.py ADDED
@@ -0,0 +1,425 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import os
17
+ from argparse import ArgumentParser, Namespace
18
+ from importlib import import_module
19
+
20
+ import huggingface_hub
21
+ import numpy as np
22
+ from packaging import version
23
+
24
+ from .. import (
25
+ FEATURE_EXTRACTOR_MAPPING,
26
+ IMAGE_PROCESSOR_MAPPING,
27
+ PROCESSOR_MAPPING,
28
+ TOKENIZER_MAPPING,
29
+ AutoConfig,
30
+ AutoFeatureExtractor,
31
+ AutoImageProcessor,
32
+ AutoProcessor,
33
+ AutoTokenizer,
34
+ is_datasets_available,
35
+ is_tf_available,
36
+ is_torch_available,
37
+ )
38
+ from ..utils import TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
39
+ from . import BaseTransformersCLICommand
40
+
41
+
42
+ if is_tf_available():
43
+ import tensorflow as tf
44
+
45
+ tf.config.experimental.enable_tensor_float_32_execution(False)
46
+
47
+ if is_torch_available():
48
+ import torch
49
+
50
+ if is_datasets_available():
51
+ from datasets import load_dataset
52
+
53
+
54
+ MAX_ERROR = 5e-5 # larger error tolerance than in our internal tests, to avoid flaky user-facing errors
55
+
56
+
57
+ def convert_command_factory(args: Namespace):
58
+ """
59
+ Factory function used to convert a model's PyTorch checkpoint into a TensorFlow 2 checkpoint.
60
+
61
+ Returns: PTtoTFCommand
62
+ """
63
+ return PTtoTFCommand(
64
+ args.model_name,
65
+ args.local_dir,
66
+ args.max_error,
67
+ args.new_weights,
68
+ args.no_pr,
69
+ args.push,
70
+ args.extra_commit_description,
71
+ args.override_model_class,
72
+ )
73
+
74
+
75
+ class PTtoTFCommand(BaseTransformersCLICommand):
76
+ @staticmethod
77
+ def register_subcommand(parser: ArgumentParser):
78
+ """
79
+ Register this command to argparse so it's available for the transformer-cli
80
+
81
+ Args:
82
+ parser: Root parser to register command-specific arguments
83
+ """
84
+ train_parser = parser.add_parser(
85
+ "pt-to-tf",
86
+ help=(
87
+ "CLI tool to convert a transformers model from a PyTorch checkpoint to a TensorFlow checkpoint."
88
+ " Can also be used to validate existing weights without opening PRs, with --no-pr."
89
+ ),
90
+ )
91
+ train_parser.add_argument(
92
+ "--model-name",
93
+ type=str,
94
+ required=True,
95
+ help="The model name, including owner/organization, as seen on the hub.",
96
+ )
97
+ train_parser.add_argument(
98
+ "--local-dir",
99
+ type=str,
100
+ default="",
101
+ help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
102
+ )
103
+ train_parser.add_argument(
104
+ "--max-error",
105
+ type=float,
106
+ default=MAX_ERROR,
107
+ help=(
108
+ f"Maximum error tolerance. Defaults to {MAX_ERROR}. This flag should be avoided, use at your own risk."
109
+ ),
110
+ )
111
+ train_parser.add_argument(
112
+ "--new-weights",
113
+ action="store_true",
114
+ help="Optional flag to create new TensorFlow weights, even if they already exist.",
115
+ )
116
+ train_parser.add_argument(
117
+ "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
118
+ )
119
+ train_parser.add_argument(
120
+ "--push",
121
+ action="store_true",
122
+ help="Optional flag to push the weights directly to `main` (requires permissions)",
123
+ )
124
+ train_parser.add_argument(
125
+ "--extra-commit-description",
126
+ type=str,
127
+ default="",
128
+ help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
129
+ )
130
+ train_parser.add_argument(
131
+ "--override-model-class",
132
+ type=str,
133
+ default=None,
134
+ help="If you think you know better than the auto-detector, you can specify the model class here. "
135
+ "Can be either an AutoModel class or a specific model class like BertForSequenceClassification.",
136
+ )
137
+ train_parser.set_defaults(func=convert_command_factory)
138
+
139
+ @staticmethod
140
+ def find_pt_tf_differences(pt_outputs, tf_outputs):
141
+ """
142
+ Compares the TensorFlow and PyTorch outputs, returning a dictionary with all tensor differences.
143
+ """
144
+ # 1. All output attributes must be the same
145
+ pt_out_attrs = set(pt_outputs.keys())
146
+ tf_out_attrs = set(tf_outputs.keys())
147
+ if pt_out_attrs != tf_out_attrs:
148
+ raise ValueError(
149
+ f"The model outputs have different attributes, aborting. (Pytorch: {pt_out_attrs}, TensorFlow:"
150
+ f" {tf_out_attrs})"
151
+ )
152
+
153
+ # 2. For each output attribute, computes the difference
154
+ def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""):
155
+ # If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
156
+ # recursivelly, keeping the name of the attribute.
157
+ if isinstance(pt_out, torch.Tensor):
158
+ tensor_difference = np.max(np.abs(pt_out.numpy() - tf_out.numpy()))
159
+ differences[attr_name] = tensor_difference
160
+ else:
161
+ root_name = attr_name
162
+ for i, pt_item in enumerate(pt_out):
163
+ # If it is a named attribute, we keep the name. Otherwise, just its index.
164
+ if isinstance(pt_item, str):
165
+ branch_name = root_name + pt_item
166
+ tf_item = tf_out[pt_item]
167
+ pt_item = pt_out[pt_item]
168
+ else:
169
+ branch_name = root_name + f"[{i}]"
170
+ tf_item = tf_out[i]
171
+ differences = _find_pt_tf_differences(pt_item, tf_item, differences, branch_name)
172
+
173
+ return differences
174
+
175
+ return _find_pt_tf_differences(pt_outputs, tf_outputs, {})
176
+
177
+ def __init__(
178
+ self,
179
+ model_name: str,
180
+ local_dir: str,
181
+ max_error: float,
182
+ new_weights: bool,
183
+ no_pr: bool,
184
+ push: bool,
185
+ extra_commit_description: str,
186
+ override_model_class: str,
187
+ *args,
188
+ ):
189
+ self._logger = logging.get_logger("transformers-cli/pt_to_tf")
190
+ self._model_name = model_name
191
+ self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
192
+ self._max_error = max_error
193
+ self._new_weights = new_weights
194
+ self._no_pr = no_pr
195
+ self._push = push
196
+ self._extra_commit_description = extra_commit_description
197
+ self._override_model_class = override_model_class
198
+
199
+ def get_inputs(self, pt_model, tf_dummy_inputs, config):
200
+ """
201
+ Returns the right inputs for the model, based on its signature.
202
+ """
203
+
204
+ def _get_audio_input():
205
+ ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
206
+ speech_samples = ds.sort("id").select(range(2))[:2]["audio"]
207
+ raw_samples = [x["array"] for x in speech_samples]
208
+ return raw_samples
209
+
210
+ model_config_class = type(pt_model.config)
211
+ if model_config_class in PROCESSOR_MAPPING:
212
+ processor = AutoProcessor.from_pretrained(self._local_dir)
213
+ if model_config_class in TOKENIZER_MAPPING and processor.tokenizer.pad_token is None:
214
+ processor.tokenizer.pad_token = processor.tokenizer.eos_token
215
+ elif model_config_class in IMAGE_PROCESSOR_MAPPING:
216
+ processor = AutoImageProcessor.from_pretrained(self._local_dir)
217
+ elif model_config_class in FEATURE_EXTRACTOR_MAPPING:
218
+ processor = AutoFeatureExtractor.from_pretrained(self._local_dir)
219
+ elif model_config_class in TOKENIZER_MAPPING:
220
+ processor = AutoTokenizer.from_pretrained(self._local_dir)
221
+ if processor.pad_token is None:
222
+ processor.pad_token = processor.eos_token
223
+ else:
224
+ raise ValueError(f"Unknown data processing type (model config type: {model_config_class})")
225
+
226
+ model_forward_signature = set(inspect.signature(pt_model.forward).parameters.keys())
227
+ processor_inputs = {}
228
+ if "input_ids" in model_forward_signature:
229
+ processor_inputs.update(
230
+ {
231
+ "text": ["Hi there!", "I am a batch with more than one row and different input lengths."],
232
+ "padding": True,
233
+ "truncation": True,
234
+ }
235
+ )
236
+ if "pixel_values" in model_forward_signature:
237
+ sample_images = load_dataset("cifar10", "plain_text", split="test")[:2]["img"]
238
+ processor_inputs.update({"images": sample_images})
239
+ if "input_features" in model_forward_signature:
240
+ feature_extractor_signature = inspect.signature(processor.feature_extractor).parameters
241
+ # Pad to the largest input length by default but take feature extractor default
242
+ # padding value if it exists e.g. "max_length" and is not False or None
243
+ if "padding" in feature_extractor_signature:
244
+ default_strategy = feature_extractor_signature["padding"].default
245
+ if default_strategy is not False and default_strategy is not None:
246
+ padding_strategy = default_strategy
247
+ else:
248
+ padding_strategy = True
249
+ else:
250
+ padding_strategy = True
251
+ processor_inputs.update({"audio": _get_audio_input(), "padding": padding_strategy})
252
+ if "input_values" in model_forward_signature: # Wav2Vec2 audio input
253
+ processor_inputs.update({"audio": _get_audio_input(), "padding": True})
254
+ pt_input = processor(**processor_inputs, return_tensors="pt")
255
+ tf_input = processor(**processor_inputs, return_tensors="tf")
256
+
257
+ # Extra input requirements, in addition to the input modality
258
+ if (
259
+ config.is_encoder_decoder
260
+ or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"))
261
+ or "decoder_input_ids" in tf_dummy_inputs
262
+ ):
263
+ decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0)
264
+ pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
265
+ tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
266
+
267
+ return pt_input, tf_input
268
+
269
+ def run(self):
270
+ # hub version 0.9.0 introduced the possibility of programmatically opening PRs with normal write tokens.
271
+ if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
272
+ raise ImportError(
273
+ "The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
274
+ " installation."
275
+ )
276
+ else:
277
+ from huggingface_hub import Repository, create_commit
278
+ from huggingface_hub._commit_api import CommitOperationAdd
279
+
280
+ # Fetch remote data
281
+ repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
282
+
283
+ # Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
284
+ config = AutoConfig.from_pretrained(self._local_dir)
285
+ architectures = config.architectures
286
+ if self._override_model_class is not None:
287
+ if self._override_model_class.startswith("TF"):
288
+ architectures = [self._override_model_class[2:]]
289
+ else:
290
+ architectures = [self._override_model_class]
291
+ try:
292
+ pt_class = getattr(import_module("transformers"), architectures[0])
293
+ except AttributeError:
294
+ raise ValueError(f"Model class {self._override_model_class} not found in transformers.")
295
+ try:
296
+ tf_class = getattr(import_module("transformers"), "TF" + architectures[0])
297
+ except AttributeError:
298
+ raise ValueError(f"TF model class TF{self._override_model_class} not found in transformers.")
299
+ elif architectures is None: # No architecture defined -- use auto classes
300
+ pt_class = getattr(import_module("transformers"), "AutoModel")
301
+ tf_class = getattr(import_module("transformers"), "TFAutoModel")
302
+ self._logger.warning("No detected architecture, using AutoModel/TFAutoModel")
303
+ else: # Architecture defined -- use it
304
+ if len(architectures) > 1:
305
+ raise ValueError(f"More than one architecture was found, aborting. (architectures = {architectures})")
306
+ self._logger.warning(f"Detected architecture: {architectures[0]}")
307
+ pt_class = getattr(import_module("transformers"), architectures[0])
308
+ try:
309
+ tf_class = getattr(import_module("transformers"), "TF" + architectures[0])
310
+ except AttributeError:
311
+ raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")
312
+
313
+ # Check the TF dummy inputs to see what keys we need in the forward pass
314
+ tf_from_pt_model = tf_class.from_config(config)
315
+ tf_dummy_inputs = tf_from_pt_model.dummy_inputs
316
+
317
+ del tf_from_pt_model # Try to keep only one model in memory at a time
318
+
319
+ # Load the model and get some basic inputs
320
+ pt_model = pt_class.from_pretrained(self._local_dir)
321
+ pt_model.eval()
322
+
323
+ pt_input, tf_input = self.get_inputs(pt_model, tf_dummy_inputs, config)
324
+
325
+ with torch.no_grad():
326
+ pt_outputs = pt_model(**pt_input, output_hidden_states=True)
327
+ del pt_model # will no longer be used, and may have a large memory footprint
328
+
329
+ tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
330
+ tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True, training=False)
331
+
332
+ # Confirms that cross loading PT weights into TF worked.
333
+ crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
334
+ output_differences = {k: v for k, v in crossload_differences.items() if "hidden" not in k}
335
+ hidden_differences = {k: v for k, v in crossload_differences.items() if "hidden" in k}
336
+ if len(output_differences) == 0 and architectures is not None:
337
+ raise ValueError(
338
+ f"Something went wrong -- the config file has architectures ({architectures}), but no model head"
339
+ " output was found. All outputs start with 'hidden'"
340
+ )
341
+ max_crossload_output_diff = max(output_differences.values()) if output_differences else 0.0
342
+ max_crossload_hidden_diff = max(hidden_differences.values())
343
+ if max_crossload_output_diff > self._max_error or max_crossload_hidden_diff > self._max_error:
344
+ raise ValueError(
345
+ "The cross-loaded TensorFlow model has different outputs, something went wrong!\n"
346
+ + f"\nList of maximum output differences above the threshold ({self._max_error}):\n"
347
+ + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > self._max_error])
348
+ + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_error}):\n"
349
+ + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_error])
350
+ )
351
+
352
+ # Save the weights in a TF format (if needed) and confirms that the results are still good
353
+ tf_weights_path = os.path.join(self._local_dir, TF2_WEIGHTS_NAME)
354
+ tf_weights_index_path = os.path.join(self._local_dir, TF2_WEIGHTS_INDEX_NAME)
355
+ if (not os.path.exists(tf_weights_path) and not os.path.exists(tf_weights_index_path)) or self._new_weights:
356
+ tf_from_pt_model.save_pretrained(self._local_dir)
357
+ del tf_from_pt_model # will no longer be used, and may have a large memory footprint
358
+
359
+ tf_model = tf_class.from_pretrained(self._local_dir)
360
+ tf_outputs = tf_model(**tf_input, output_hidden_states=True)
361
+
362
+ conversion_differences = self.find_pt_tf_differences(pt_outputs, tf_outputs)
363
+ output_differences = {k: v for k, v in conversion_differences.items() if "hidden" not in k}
364
+ hidden_differences = {k: v for k, v in conversion_differences.items() if "hidden" in k}
365
+ if len(output_differences) == 0 and architectures is not None:
366
+ raise ValueError(
367
+ f"Something went wrong -- the config file has architectures ({architectures}), but no model head"
368
+ " output was found. All outputs start with 'hidden'"
369
+ )
370
+ max_conversion_output_diff = max(output_differences.values()) if output_differences else 0.0
371
+ max_conversion_hidden_diff = max(hidden_differences.values())
372
+ if max_conversion_output_diff > self._max_error or max_conversion_hidden_diff > self._max_error:
373
+ raise ValueError(
374
+ "The converted TensorFlow model has different outputs, something went wrong!\n"
375
+ + f"\nList of maximum output differences above the threshold ({self._max_error}):\n"
376
+ + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > self._max_error])
377
+ + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_error}):\n"
378
+ + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_error])
379
+ )
380
+
381
+ commit_message = "Update TF weights" if self._new_weights else "Add TF weights"
382
+ if self._push:
383
+ repo.git_add(auto_lfs_track=True)
384
+ repo.git_commit(commit_message)
385
+ repo.git_push(blocking=True) # this prints a progress bar with the upload
386
+ self._logger.warning(f"TF weights pushed into {self._model_name}")
387
+ elif not self._no_pr:
388
+ self._logger.warning("Uploading the weights into a new PR...")
389
+ commit_description = (
390
+ "Model converted by the [`transformers`' `pt_to_tf`"
391
+ " CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py). "
392
+ "All converted model outputs and hidden layers were validated against its PyTorch counterpart.\n\n"
393
+ f"Maximum crossload output difference={max_crossload_output_diff:.3e}; "
394
+ f"Maximum crossload hidden layer difference={max_crossload_hidden_diff:.3e};\n"
395
+ f"Maximum conversion output difference={max_conversion_output_diff:.3e}; "
396
+ f"Maximum conversion hidden layer difference={max_conversion_hidden_diff:.3e};\n"
397
+ )
398
+ if self._max_error > MAX_ERROR:
399
+ commit_description += (
400
+ f"\n\nCAUTION: The maximum admissible error was manually increased to {self._max_error}!"
401
+ )
402
+ if self._extra_commit_description:
403
+ commit_description += "\n\n" + self._extra_commit_description
404
+
405
+ # sharded model -> adds all related files (index and .h5 shards)
406
+ if os.path.exists(tf_weights_index_path):
407
+ operations = [
408
+ CommitOperationAdd(path_in_repo=TF2_WEIGHTS_INDEX_NAME, path_or_fileobj=tf_weights_index_path)
409
+ ]
410
+ for shard_path in tf.io.gfile.glob(self._local_dir + "/tf_model-*.h5"):
411
+ operations += [
412
+ CommitOperationAdd(path_in_repo=os.path.basename(shard_path), path_or_fileobj=shard_path)
413
+ ]
414
+ else:
415
+ operations = [CommitOperationAdd(path_in_repo=TF2_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)]
416
+
417
+ hub_pr_url = create_commit(
418
+ repo_id=self._model_name,
419
+ operations=operations,
420
+ commit_message=commit_message,
421
+ commit_description=commit_description,
422
+ repo_type="model",
423
+ create_pr=True,
424
+ ).pr_url
425
+ self._logger.warning(f"PR open in {hub_pr_url}")
transformers_4_35_0/commands/run.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from argparse import ArgumentParser
16
+
17
+ from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
18
+ from ..utils import logging
19
+ from . import BaseTransformersCLICommand
20
+
21
+
22
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23
+
24
+
25
+ def try_infer_format_from_ext(path: str):
26
+ if not path:
27
+ return "pipe"
28
+
29
+ for ext in PipelineDataFormat.SUPPORTED_FORMATS:
30
+ if path.endswith(ext):
31
+ return ext
32
+
33
+ raise Exception(
34
+ f"Unable to determine file format from file extension {path}. "
35
+ f"Please provide the format through --format {PipelineDataFormat.SUPPORTED_FORMATS}"
36
+ )
37
+
38
+
39
+ def run_command_factory(args):
40
+ nlp = pipeline(
41
+ task=args.task,
42
+ model=args.model if args.model else None,
43
+ config=args.config,
44
+ tokenizer=args.tokenizer,
45
+ device=args.device,
46
+ )
47
+ format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format
48
+ reader = PipelineDataFormat.from_str(
49
+ format=format,
50
+ output_path=args.output,
51
+ input_path=args.input,
52
+ column=args.column if args.column else nlp.default_input_names,
53
+ overwrite=args.overwrite,
54
+ )
55
+ return RunCommand(nlp, reader)
56
+
57
+
58
+ class RunCommand(BaseTransformersCLICommand):
59
+ def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
60
+ self._nlp = nlp
61
+ self._reader = reader
62
+
63
+ @staticmethod
64
+ def register_subcommand(parser: ArgumentParser):
65
+ run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
66
+ run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run")
67
+ run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
68
+ run_parser.add_argument("--output", type=str, help="Path to the file where the results will be written.")
69
+ run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
70
+ run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
71
+ run_parser.add_argument(
72
+ "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
73
+ )
74
+ run_parser.add_argument(
75
+ "--column",
76
+ type=str,
77
+ help="Name of the column to use as input. (For multi-column input such as QA, use column1,column2.)",
78
+ )
79
+ run_parser.add_argument(
80
+ "--format",
81
+ type=str,
82
+ default="infer",
83
+ choices=PipelineDataFormat.SUPPORTED_FORMATS,
84
+ help="Input format to read from",
85
+ )
86
+ run_parser.add_argument(
87
+ "--device",
88
+ type=int,
89
+ default=-1,
90
+ help="Device to run on: -1 indicates CPU, >= 0 indicates the corresponding GPU (default: -1).",
91
+ )
92
+ run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
93
+ run_parser.set_defaults(func=run_command_factory)
94
+
95
+ def run(self):
96
+ nlp, outputs = self._nlp, []
97
+
98
+ for entry in self._reader:
99
+ output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
100
+ if isinstance(output, dict):
101
+ outputs.append(output)
102
+ else:
103
+ outputs += output
104
+
105
+ # Saving data
106
+ if self._nlp.binary_output:
107
+ binary_path = self._reader.save_binary(outputs)
108
+ logger.warning(f"Current pipeline requires output to be in binary format, saving at {binary_path}")
109
+ else:
110
+ self._reader.save(outputs)
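
The `--format infer` path relies only on the file extension, as `try_infer_format_from_ext` above shows. A minimal, self-contained sketch of that behaviour; the format list below is an assumption standing in for `PipelineDataFormat.SUPPORTED_FORMATS`, which is not shown in this diff:

```python
# Minimal sketch of the `--format infer` behaviour.
SUPPORTED_FORMATS = ("json", "csv", "pipe")  # assumption: stands in for PipelineDataFormat.SUPPORTED_FORMATS

def infer_format(path: str) -> str:
    if not path:
        return "pipe"  # no path given -> read from stdin / write to stdout
    for ext in SUPPORTED_FORMATS:
        if path.endswith(ext):
            return ext
    raise ValueError(f"Unable to determine file format from file extension {path}.")

print(infer_format("reviews.csv"))  # -> "csv"
print(infer_format(""))             # -> "pipe"
```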
transformers_4_35_0/commands/serving.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from argparse import ArgumentParser, Namespace
16
+ from typing import Any, List, Optional
17
+
18
+ from ..pipelines import Pipeline, get_supported_tasks, pipeline
19
+ from ..utils import logging
20
+ from . import BaseTransformersCLICommand
21
+
22
+
23
+ try:
24
+ from fastapi import Body, FastAPI, HTTPException
25
+ from fastapi.routing import APIRoute
26
+ from pydantic import BaseModel
27
+ from starlette.responses import JSONResponse
28
+ from uvicorn import run
29
+
30
+ _serve_dependencies_installed = True
31
+ except (ImportError, AttributeError):
32
+ BaseModel = object
33
+
34
+ def Body(*x, **y):
35
+ pass
36
+
37
+ _serve_dependencies_installed = False
38
+
39
+
40
+ logger = logging.get_logger("transformers-cli/serving")
41
+
42
+
43
+ def serve_command_factory(args: Namespace):
44
+ """
45
+ Factory function used to instantiate serving server from provided command line arguments.
46
+
47
+ Returns: ServeCommand
48
+ """
49
+ nlp = pipeline(
50
+ task=args.task,
51
+ model=args.model if args.model else None,
52
+ config=args.config,
53
+ tokenizer=args.tokenizer,
54
+ device=args.device,
55
+ )
56
+ return ServeCommand(nlp, args.host, args.port, args.workers)
57
+
58
+
59
+ class ServeModelInfoResult(BaseModel):
60
+ """
61
+ Expose model information
62
+ """
63
+
64
+ infos: dict
65
+
66
+
67
+ class ServeTokenizeResult(BaseModel):
68
+ """
69
+ Tokenize result model
70
+ """
71
+
72
+ tokens: List[str]
73
+ tokens_ids: Optional[List[int]]
74
+
75
+
76
+ class ServeDeTokenizeResult(BaseModel):
77
+ """
78
+ DeTokenize result model
79
+ """
80
+
81
+ text: str
82
+
83
+
84
+ class ServeForwardResult(BaseModel):
85
+ """
86
+ Forward result model
87
+ """
88
+
89
+ output: Any
90
+
91
+
92
+ class ServeCommand(BaseTransformersCLICommand):
93
+ @staticmethod
94
+ def register_subcommand(parser: ArgumentParser):
95
+ """
96
+ Register this command to argparse so it's available for the transformer-cli
97
+
98
+ Args:
99
+ parser: Root parser to register command-specific arguments
100
+ """
101
+ serve_parser = parser.add_parser(
102
+ "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
103
+ )
104
+ serve_parser.add_argument(
105
+ "--task",
106
+ type=str,
107
+ choices=get_supported_tasks(),
108
+ help="The task to run the pipeline on",
109
+ )
110
+ serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
111
+ serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
112
+ serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
113
+ serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
114
+ serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
115
+ serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
116
+ serve_parser.add_argument(
117
+ "--device",
118
+ type=int,
119
+ default=-1,
120
+ help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
121
+ )
122
+ serve_parser.set_defaults(func=serve_command_factory)
123
+
124
+ def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
125
+ self._pipeline = pipeline
126
+
127
+ self.host = host
128
+ self.port = port
129
+ self.workers = workers
130
+
131
+ if not _serve_dependencies_installed:
132
+ raise RuntimeError(
133
+ "Using serve command requires FastAPI and uvicorn. "
134
+ 'Please install transformers with [serving]: pip install "transformers[serving]".'
135
+ "Or install FastAPI and uvicorn separately."
136
+ )
137
+ else:
138
+ logger.info(f"Serving model over {host}:{port}")
139
+ self._app = FastAPI(
140
+ routes=[
141
+ APIRoute(
142
+ "/",
143
+ self.model_info,
144
+ response_model=ServeModelInfoResult,
145
+ response_class=JSONResponse,
146
+ methods=["GET"],
147
+ ),
148
+ APIRoute(
149
+ "/tokenize",
150
+ self.tokenize,
151
+ response_model=ServeTokenizeResult,
152
+ response_class=JSONResponse,
153
+ methods=["POST"],
154
+ ),
155
+ APIRoute(
156
+ "/detokenize",
157
+ self.detokenize,
158
+ response_model=ServeDeTokenizeResult,
159
+ response_class=JSONResponse,
160
+ methods=["POST"],
161
+ ),
162
+ APIRoute(
163
+ "/forward",
164
+ self.forward,
165
+ response_model=ServeForwardResult,
166
+ response_class=JSONResponse,
167
+ methods=["POST"],
168
+ ),
169
+ ],
170
+ timeout=600,
171
+ )
172
+
173
+ def run(self):
174
+ run(self._app, host=self.host, port=self.port, workers=self.workers)
175
+
176
+ def model_info(self):
177
+ return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
178
+
179
+ def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
180
+ """
181
+ Tokenize the provided input and eventually returns corresponding tokens id: - **text_input**: String to
182
+ tokenize - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer
183
+ mapping.
184
+ """
185
+ try:
186
+ tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
187
+
188
+ if return_ids:
189
+ tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
190
+ return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
191
+ else:
192
+ return ServeTokenizeResult(tokens=tokens_txt)
193
+
194
+ except Exception as e:
195
+ raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
196
+
197
+ def detokenize(
198
+ self,
199
+ tokens_ids: List[int] = Body(None, embed=True),
200
+ skip_special_tokens: bool = Body(False, embed=True),
201
+ cleanup_tokenization_spaces: bool = Body(True, embed=True),
202
+ ):
203
+ """
204
+ Detokenize the provided tokens ids to readable text: - **tokens_ids**: List of tokens ids -
205
+ **skip_special_tokens**: Flag indicating to not try to decode special tokens - **cleanup_tokenization_spaces**:
206
+ Flag indicating to remove all leading/trailing spaces and intermediate ones.
207
+ """
208
+ try:
209
+ decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
210
+ return ServeDeTokenizeResult(model="", text=decoded_str)
211
+ except Exception as e:
212
+ raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
213
+
214
+ async def forward(self, inputs=Body(None, embed=True)):
215
+ """
216
+ **inputs**: **attention_mask**: **tokens_type_ids**:
217
+ """
218
+
219
+ # Check we don't have empty string
220
+ if len(inputs) == 0:
221
+ return ServeForwardResult(output=[], attention=[])
222
+
223
+ try:
224
+ # Forward through the model
225
+ output = self._pipeline(inputs)
226
+ return ServeForwardResult(output=output)
227
+ except Exception as e:
228
+ raise HTTPException(500, {"error": str(e)})
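
Once the server is running, the routes registered above can be exercised with plain HTTP calls. A hedged sketch (host, port and payloads are illustrative; the JSON keys follow the `Body(..., embed=True)` parameter names declared above):

```python
import requests

base = "http://localhost:8888"  # default --host/--port values from the parser above

print(requests.get(f"{base}/").json())  # model_info
print(requests.post(f"{base}/tokenize", json={"text_input": "Hello world", "return_ids": True}).json())
print(requests.post(f"{base}/forward", json={"inputs": "Hello world"}).json())
```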
transformers_4_35_0/commands/train.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from argparse import ArgumentParser, Namespace
17
+
18
+ from ..data import SingleSentenceClassificationProcessor as Processor
19
+ from ..pipelines import TextClassificationPipeline
20
+ from ..utils import is_tf_available, is_torch_available, logging
21
+ from . import BaseTransformersCLICommand
22
+
23
+
24
+ if not is_tf_available() and not is_torch_available():
25
+ raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
26
+
27
+ # TF training parameters
28
+ USE_XLA = False
29
+ USE_AMP = False
30
+
31
+
32
+ def train_command_factory(args: Namespace):
33
+ """
34
+ Factory function used to instantiate training command from provided command line arguments.
35
+
36
+ Returns: TrainCommand
37
+ """
38
+ return TrainCommand(args)
39
+
40
+
41
+ class TrainCommand(BaseTransformersCLICommand):
42
+ @staticmethod
43
+ def register_subcommand(parser: ArgumentParser):
44
+ """
45
+ Register this command to argparse so it's available for the transformer-cli
46
+
47
+ Args:
48
+ parser: Root parser to register command-specific arguments
49
+ """
50
+ train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")
51
+
52
+ train_parser.add_argument(
53
+ "--train_data",
54
+ type=str,
55
+ required=True,
56
+ help="Path to the training (and optionally evaluation) dataset: a CSV with tab-separated labels and sentences.",
57
+ )
58
+ train_parser.add_argument(
59
+ "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
60
+ )
61
+ train_parser.add_argument(
62
+ "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
63
+ )
64
+ train_parser.add_argument(
65
+ "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
66
+ )
67
+ train_parser.add_argument(
68
+ "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
69
+ )
70
+
71
+ train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
72
+ train_parser.add_argument(
73
+ "--validation_split",
74
+ type=float,
75
+ default=0.1,
76
+ help="if validation dataset is not provided, fraction of train dataset to use as validation dataset.",
77
+ )
78
+
79
+ train_parser.add_argument("--output", type=str, default="./", help="Path where the trained model will be saved.")
80
+
81
+ train_parser.add_argument(
82
+ "--task", type=str, default="text_classification", help="Task to train the model on."
83
+ )
84
+ train_parser.add_argument(
85
+ "--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model."
86
+ )
87
+ train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
88
+ train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
89
+ train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
90
+ train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
91
+ train_parser.set_defaults(func=train_command_factory)
92
+
93
+ def __init__(self, args: Namespace):
94
+ self.logger = logging.get_logger("transformers-cli/training")
95
+
96
+ self.framework = "tf" if is_tf_available() else "torch"
97
+
98
+ os.makedirs(args.output, exist_ok=True)
99
+ self.output = args.output
100
+
101
+ self.column_label = args.column_label
102
+ self.column_text = args.column_text
103
+ self.column_id = args.column_id
104
+
105
+ self.logger.info(f"Loading {args.task} pipeline for {args.model}")
106
+ if args.task == "text_classification":
107
+ self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
108
+ elif args.task == "token_classification":
109
+ raise NotImplementedError
110
+ elif args.task == "question_answering":
111
+ raise NotImplementedError
112
+
113
+ self.logger.info(f"Loading dataset from {args.train_data}")
114
+ self.train_dataset = Processor.create_from_csv(
115
+ args.train_data,
116
+ column_label=args.column_label,
117
+ column_text=args.column_text,
118
+ column_id=args.column_id,
119
+ skip_first_row=args.skip_first_row,
120
+ )
121
+ self.valid_dataset = None
122
+ if args.validation_data:
123
+ self.logger.info(f"Loading validation dataset from {args.validation_data}")
124
+ self.valid_dataset = Processor.create_from_csv(
125
+ args.validation_data,
126
+ column_label=args.column_label,
127
+ column_text=args.column_text,
128
+ column_id=args.column_id,
129
+ skip_first_row=args.skip_first_row,
130
+ )
131
+
132
+ self.validation_split = args.validation_split
133
+ self.train_batch_size = args.train_batch_size
134
+ self.valid_batch_size = args.valid_batch_size
135
+ self.learning_rate = args.learning_rate
136
+ self.adam_epsilon = args.adam_epsilon
137
+
138
+ def run(self):
139
+ if self.framework == "tf":
140
+ return self.run_tf()
141
+ return self.run_torch()
142
+
143
+ def run_torch(self):
144
+ raise NotImplementedError
145
+
146
+ def run_tf(self):
147
+ self.pipeline.fit(
148
+ self.train_dataset,
149
+ validation_data=self.valid_dataset,
150
+ validation_split=self.validation_split,
151
+ learning_rate=self.learning_rate,
152
+ adam_epsilon=self.adam_epsilon,
153
+ train_batch_size=self.train_batch_size,
154
+ valid_batch_size=self.valid_batch_size,
155
+ )
156
+
157
+ # Save trained pipeline
158
+ self.pipeline.save_pretrained(self.output)
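
The `--train_data` argument above expects tab-separated labels and sentences, with the label, text and id columns matching the `--column_*` defaults (0, 1, 2). A hedged sketch of producing such a file (file name and rows are made up):

```python
# Illustrative only: write a tab-separated file in the layout the --column_* defaults expect.
rows = [
    ("positive", "I really enjoyed this movie.", "0"),
    ("negative", "The plot made no sense at all.", "1"),
]
with open("train.csv", "w", encoding="utf-8") as f:
    f.write("label\ttext\tid\n")  # header row, skipped with --skip_first_row
    for label, text, idx in rows:
        f.write(f"{label}\t{text}\t{idx}\n")
```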
transformers_4_35_0/commands/transformers_cli.py ADDED
@@ -0,0 +1,59 @@
1
+ #!/usr/bin/env python
2
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .add_new_model import AddNewModelCommand
19
+ from .add_new_model_like import AddNewModelLikeCommand
20
+ from .convert import ConvertCommand
21
+ from .download import DownloadCommand
22
+ from .env import EnvironmentCommand
23
+ from .lfs import LfsCommands
24
+ from .pt_to_tf import PTtoTFCommand
25
+ from .run import RunCommand
26
+ from .serving import ServeCommand
27
+ from .user import UserCommands
28
+
29
+
30
+ def main():
31
+ parser = ArgumentParser("Transformers CLI tool", usage="transformers-cli <command> [<args>]")
32
+ commands_parser = parser.add_subparsers(help="transformers-cli command helpers")
33
+
34
+ # Register commands
35
+ ConvertCommand.register_subcommand(commands_parser)
36
+ DownloadCommand.register_subcommand(commands_parser)
37
+ EnvironmentCommand.register_subcommand(commands_parser)
38
+ RunCommand.register_subcommand(commands_parser)
39
+ ServeCommand.register_subcommand(commands_parser)
40
+ UserCommands.register_subcommand(commands_parser)
41
+ AddNewModelCommand.register_subcommand(commands_parser)
42
+ AddNewModelLikeCommand.register_subcommand(commands_parser)
43
+ LfsCommands.register_subcommand(commands_parser)
44
+ PTtoTFCommand.register_subcommand(commands_parser)
45
+
46
+ # Let's go
47
+ args = parser.parse_args()
48
+
49
+ if not hasattr(args, "func"):
50
+ parser.print_help()
51
+ exit(1)
52
+
53
+ # Run
54
+ service = args.func(args)
55
+ service.run()
56
+
57
+
58
+ if __name__ == "__main__":
59
+ main()
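
Every command registered in `main()` follows the same contract: a static `register_subcommand` that wires argparse and calls `set_defaults(func=...)`, plus a `run()` that does the work. A minimal sketch of a hypothetical extra command following that contract (`HelloCommand` is an illustration, not part of this commit):

```python
from argparse import ArgumentParser

class HelloCommand:  # hypothetical example following BaseTransformersCLICommand's contract
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        hello_parser = parser.add_parser("hello", help="Print a greeting.")
        hello_parser.add_argument("--name", type=str, default="world")
        hello_parser.set_defaults(func=lambda args: HelloCommand(args.name))

    def __init__(self, name: str):
        self._name = name

    def run(self):
        print(f"Hello, {self._name}!")
```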
transformers_4_35_0/commands/user.py ADDED
@@ -0,0 +1,197 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import subprocess
16
+ from argparse import ArgumentParser
17
+ from typing import List, Union
18
+
19
+ from huggingface_hub.hf_api import HfFolder, create_repo, whoami
20
+ from requests.exceptions import HTTPError
21
+
22
+ from . import BaseTransformersCLICommand
23
+
24
+
25
+ class UserCommands(BaseTransformersCLICommand):
26
+ @staticmethod
27
+ def register_subcommand(parser: ArgumentParser):
28
+ login_parser = parser.add_parser("login", help="Log in using the same credentials as on huggingface.co")
29
+ login_parser.set_defaults(func=lambda args: LoginCommand(args))
30
+ whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.")
31
+ whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
32
+ logout_parser = parser.add_parser("logout", help="Log out")
33
+ logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
34
+
35
+ # new system: git-based repo system
36
+ repo_parser = parser.add_parser(
37
+ "repo",
38
+ help="Deprecated: use `huggingface-cli` instead. Commands to interact with your huggingface.co repos.",
39
+ )
40
+ repo_subparsers = repo_parser.add_subparsers(
41
+ help="Deprecated: use `huggingface-cli` instead. huggingface.co repos related commands"
42
+ )
43
+ repo_create_parser = repo_subparsers.add_parser(
44
+ "create", help="Deprecated: use `huggingface-cli` instead. Create a new repo on huggingface.co"
45
+ )
46
+ repo_create_parser.add_argument(
47
+ "name",
48
+ type=str,
49
+ help="Name for your model's repo. Will be namespaced under your username to build the model id.",
50
+ )
51
+ repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
52
+ repo_create_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
53
+ repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args))
54
+
55
+
56
+ class ANSI:
57
+ """
58
+ Helper for en.wikipedia.org/wiki/ANSI_escape_code
59
+ """
60
+
61
+ _bold = "\u001b[1m"
62
+ _red = "\u001b[31m"
63
+ _gray = "\u001b[90m"
64
+ _reset = "\u001b[0m"
65
+
66
+ @classmethod
67
+ def bold(cls, s):
68
+ return f"{cls._bold}{s}{cls._reset}"
69
+
70
+ @classmethod
71
+ def red(cls, s):
72
+ return f"{cls._bold}{cls._red}{s}{cls._reset}"
73
+
74
+ @classmethod
75
+ def gray(cls, s):
76
+ return f"{cls._gray}{s}{cls._reset}"
77
+
78
+
79
+ def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
80
+ """
81
+ Inspired by:
82
+
83
+ - stackoverflow.com/a/8356620/593036
84
+ - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
85
+ """
86
+ col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
87
+ row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
88
+ lines = []
89
+ lines.append(row_format.format(*headers))
90
+ lines.append(row_format.format(*["-" * w for w in col_widths]))
91
+ for row in rows:
92
+ lines.append(row_format.format(*row))
93
+ return "\n".join(lines)
94
+
95
+
96
+ class BaseUserCommand:
97
+ def __init__(self, args):
98
+ self.args = args
99
+
100
+
101
+ class LoginCommand(BaseUserCommand):
102
+ def run(self):
103
+ print(
104
+ ANSI.red(
105
+ "ERROR! `huggingface-cli login` uses an outdated login mechanism "
106
+ "that is not compatible with the Hugging Face Hub backend anymore. "
107
+ "Please use `huggingface-cli login instead."
108
+ )
109
+ )
110
+
111
+
112
+ class WhoamiCommand(BaseUserCommand):
113
+ def run(self):
114
+ print(
115
+ ANSI.red(
116
+ "WARNING! `transformers-cli whoami` is deprecated and will be removed in v5. Please use "
117
+ "`huggingface-cli whoami` instead."
118
+ )
119
+ )
120
+ token = HfFolder.get_token()
121
+ if token is None:
122
+ print("Not logged in")
123
+ exit()
124
+ try:
125
+ user, orgs = whoami(token)
126
+ print(user)
127
+ if orgs:
128
+ print(ANSI.bold("orgs: "), ",".join(orgs))
129
+ except HTTPError as e:
130
+ print(e)
131
+ print(ANSI.red(e.response.text))
132
+ exit(1)
133
+
134
+
135
+ class LogoutCommand(BaseUserCommand):
136
+ def run(self):
137
+ print(
138
+ ANSI.red(
139
+ "ERROR! `transformers-cli logout` uses an outdated logout mechanism "
140
+ "that is not compatible with the Hugging Face Hub backend anymore. "
141
+ "Please use `huggingface-cli logout instead."
142
+ )
143
+ )
144
+
145
+
146
+ class RepoCreateCommand(BaseUserCommand):
147
+ def run(self):
148
+ print(
149
+ ANSI.red(
150
+ "WARNING! Managing repositories through transformers-cli is deprecated. "
151
+ "Please use `huggingface-cli` instead."
152
+ )
153
+ )
154
+ token = HfFolder.get_token()
155
+ if token is None:
156
+ print("Not logged in")
157
+ exit(1)
158
+ try:
159
+ stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
160
+ print(ANSI.gray(stdout.strip()))
161
+ except FileNotFoundError:
162
+ print("Looks like you do not have git installed, please install.")
163
+
164
+ try:
165
+ stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
166
+ print(ANSI.gray(stdout.strip()))
167
+ except FileNotFoundError:
168
+ print(
169
+ ANSI.red(
170
+ "Looks like you do not have git-lfs installed, please install."
171
+ " You can install from https://git-lfs.github.com/."
172
+ " Then run `git lfs install` (you only have to do this once)."
173
+ )
174
+ )
175
+ print("")
176
+
177
+ user, _ = whoami(token)
178
+ namespace = self.args.organization if self.args.organization is not None else user
179
+ full_name = f"{namespace}/{self.args.name}"
180
+ print(f"You are about to create {ANSI.bold(full_name)}")
181
+
182
+ if not self.args.yes:
183
+ choice = input("Proceed? [Y/n] ").lower()
184
+ if not (choice == "" or choice == "y" or choice == "yes"):
185
+ print("Abort")
186
+ exit()
187
+ try:
188
+ url = create_repo(token, name=self.args.name, organization=self.args.organization)
189
+ except HTTPError as e:
190
+ print(e)
191
+ print(ANSI.red(e.response.text))
192
+ exit(1)
193
+ print("\nYour repo now lives at:")
194
+ print(f" {ANSI.bold(url)}")
195
+ print("\nYou can clone it locally with the command below, and commit/push as usual.")
196
+ print(f"\n git clone {url}")
197
+ print("")
transformers_4_35_0/configuration_utils.py ADDED
@@ -0,0 +1,1075 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Configuration base class and utilities."""
17
+
18
+
19
+ import copy
20
+ import json
21
+ import os
22
+ import re
23
+ import warnings
24
+ from typing import Any, Dict, List, Optional, Tuple, Union
25
+
26
+ from packaging import version
27
+
28
+ from . import __version__
29
+ from .dynamic_module_utils import custom_object_save
30
+ from .utils import (
31
+ CONFIG_NAME,
32
+ PushToHubMixin,
33
+ add_model_info_to_auto_map,
34
+ cached_file,
35
+ copy_func,
36
+ download_url,
37
+ extract_commit_hash,
38
+ is_remote_url,
39
+ is_torch_available,
40
+ logging,
41
+ )
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
47
+
48
+
49
+ class PretrainedConfig(PushToHubMixin):
50
+ # no-format
51
+ r"""
52
+ Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
53
+ methods for loading/downloading/saving configurations.
54
+
55
+ <Tip>
56
+
57
+ A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
58
+ initialize a model does **not** load the model weights. It only affects the model's configuration.
59
+
60
+ </Tip>
61
+
62
+ Class attributes (overridden by derived classes):
63
+
64
+ - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
65
+ the correct object in [`~transformers.AutoConfig`].
66
+ - **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
67
+ config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
68
+ [`~transformers.EncoderDecoderConfig`] or [`~RagConfig`].
69
+ - **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
70
+ outputs of the model during inference.
71
+ - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
72
+ naming of attributes.
73
+
74
+ Common attributes (present in all subclasses):
75
+
76
+ - **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
77
+ embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
78
+ - **hidden_size** (`int`) -- The hidden size of the model.
79
+ - **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the
80
+ model.
81
+ - **num_hidden_layers** (`int`) -- The number of blocks in the model.
82
+
83
+ Arg:
84
+ name_or_path (`str`, *optional*, defaults to `""`):
85
+ Store the string that was passed to [`PreTrainedModel.from_pretrained`] or
86
+ [`TFPreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path` if the configuration was created
87
+ with such a method.
88
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
89
+ Whether or not the model should return all hidden-states.
90
+ output_attentions (`bool`, *optional*, defaults to `False`):
91
+ Whether or not the model should return all attentions.
92
+ return_dict (`bool`, *optional*, defaults to `True`):
93
+ Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple.
94
+ is_encoder_decoder (`bool`, *optional*, defaults to `False`):
95
+ Whether the model is used as an encoder/decoder or not.
96
+ is_decoder (`bool`, *optional*, defaults to `False`):
97
+ Whether the model is used as decoder or not (in which case it's used as an encoder).
98
+ cross_attention_hidden_size (`int`, *optional*):
99
+ The hidden size of the cross-attention layer in case the model is used as a decoder in an encoder-decoder
100
+ setting and the cross-attention hidden dimension differs from `self.config.hidden_size`.
101
+ add_cross_attention (`bool`, *optional*, defaults to `False`):
102
+ Whether cross-attention layers should be added to the model. Note, this option is only relevant for models
103
+ that can be used as decoder models within the [`EncoderDecoderModel`] class, which consists of all models
104
+ in `AUTO_MODELS_FOR_CAUSAL_LM`.
105
+ tie_encoder_decoder (`bool`, *optional*, defaults to `False`):
106
+ Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder
107
+ and decoder model to have the exact same parameter names.
108
+ prune_heads (`Dict[int, List[int]]`, *optional*, defaults to `{}`):
109
+ Pruned heads of the model. The keys are the selected layer indices and the associated values, the list of
110
+ heads to prune in said layer.
111
+
112
+ For instance `{1: [0, 2], 2: [2, 3]}` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
113
+ chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
114
+ The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
115
+ the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
116
+ sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
117
+ Forward Chunking work?](../glossary.html#feed-forward-chunking).
118
+
119
+ > Parameters for sequence generation
120
+
121
+ max_length (`int`, *optional*, defaults to 20):
122
+ Maximum length that will be used by default in the `generate` method of the model.
123
+ min_length (`int`, *optional*, defaults to 0):
124
+ Minimum length that will be used by default in the `generate` method of the model.
125
+ do_sample (`bool`, *optional*, defaults to `False`):
126
+ Flag that will be used by default in the `generate` method of the model. Whether or not to use sampling ;
127
+ use greedy decoding otherwise.
128
+ early_stopping (`bool`, *optional*, defaults to `False`):
129
+ Flag that will be used by default in the `generate` method of the model. Whether to stop the beam search
130
+ when at least `num_beams` sentences are finished per batch or not.
131
+ num_beams (`int`, *optional*, defaults to 1):
132
+ Number of beams for beam search that will be used by default in the `generate` method of the model. 1 means
133
+ no beam search.
134
+ num_beam_groups (`int`, *optional*, defaults to 1):
135
+ Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams
136
+ that will be used by default in the `generate` method of the model. 1 means no group beam search.
137
+ diversity_penalty (`float`, *optional*, defaults to 0.0):
138
+ Value to control diversity for group beam search. that will be used by default in the `generate` method of
139
+ the model. 0 means no diversity penalty. The higher the penalty, the more diverse are the outputs.
140
+ temperature (`float`, *optional*, defaults to 1.0):
141
+ The value used to module the next token probabilities that will be used by default in the `generate` method
142
+ of the model. Must be strictly positive.
143
+ top_k (`int`, *optional*, defaults to 50):
144
+ Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in
145
+ the `generate` method of the model.
146
+ top_p (`float`, *optional*, defaults to 1):
147
+ Value that will be used by default in the `generate` method of the model for `top_p`. If set to float < 1,
148
+ only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
149
+ typical_p (`float`, *optional*, defaults to 1):
150
+ Local typicality measures how similar the conditional probability of predicting a target token next is to
151
+ the expected conditional probability of predicting a random token next, given the partial text already
152
+ generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that
153
+ add up to `typical_p` or higher are kept for generation. See [this
154
+ paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
155
+ repetition_penalty (`float`, *optional*, defaults to 1):
156
+ Parameter for repetition penalty that will be used by default in the `generate` method of the model. 1.0
157
+ means no penalty.
158
+ length_penalty (`float`, *optional*, defaults to 1):
159
+ Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
160
+ the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
161
+ likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
162
+ `length_penalty` < 0.0 encourages shorter sequences.
163
+ no_repeat_ngram_size (`int`, *optional*, defaults to 0) -- Value that will be used by default in the
164
+ `generate` method of the model for `no_repeat_ngram_size`. If set to int > 0, all ngrams of that size can
165
+ only occur once.
166
+ encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0) -- Value that will be used by
167
+ default in the `generate` method of the model for `encoder_no_repeat_ngram_size`. If set to int > 0, all
168
+ ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`.
169
+ bad_words_ids (`List[int]`, *optional*):
170
+ List of token ids that are not allowed to be generated that will be used by default in the `generate`
171
+ method of the model. In order to get the tokens of the words that should not appear in the generated text,
172
+ use `tokenizer.encode(bad_word, add_prefix_space=True)`.
173
+ num_return_sequences (`int`, *optional*, defaults to 1):
174
+ Number of independently computed returned sequences for each element in the batch that will be used by
175
+ default in the `generate` method of the model.
176
+ output_scores (`bool`, *optional*, defaults to `False`):
177
+ Whether the model should return the logits when used for generation.
178
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
179
+ Whether the model should return a [`~transformers.utils.ModelOutput`] instead of a `torch.LongTensor`.
180
+ forced_bos_token_id (`int`, *optional*):
181
+ The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
182
+ multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
183
+ language token.
184
+ forced_eos_token_id (`int`, *optional*):
185
+ The id of the token to force as the last generated token when `max_length` is reached.
186
+ remove_invalid_values (`bool`, *optional*):
187
+ Whether to remove possible _nan_ and _inf_ outputs of the model to prevent the generation method from crashing.
188
+ Note that using `remove_invalid_values` can slow down generation.
189
+
190
+ > Parameters for fine-tuning tasks
191
+
192
+ architectures (`List[str]`, *optional*):
193
+ Model architectures that can be used with the model pretrained weights.
194
+ finetuning_task (`str`, *optional*):
195
+ Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow
196
+ or PyTorch) checkpoint.
197
+ id2label (`Dict[int, str]`, *optional*):
198
+ A map from index (for instance prediction index, or target index) to label.
199
+ label2id (`Dict[str, int]`, *optional*): A map from label to index for the model.
200
+ num_labels (`int`, *optional*):
201
+ Number of labels to use in the last layer added to the model, typically for a classification task.
202
+ task_specific_params (`Dict[str, Any]`, *optional*):
203
+ Additional keyword arguments to store for the current task.
204
+ problem_type (`str`, *optional*):
205
+ Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
206
+ `"single_label_classification"` or `"multi_label_classification"`.
207
+
208
+ > Parameters linked to the tokenizer
209
+
210
+ tokenizer_class (`str`, *optional*):
211
+ The name of the associated tokenizer class to use (if none is set, will use the tokenizer associated to the
212
+ model by default).
213
+ prefix (`str`, *optional*):
214
+ A specific prompt that should be added at the beginning of each text before calling the model.
215
+ bos_token_id (`int`, *optional*): The id of the _beginning-of-stream_ token.
216
+ pad_token_id (`int`, *optional*): The id of the _padding_ token.
217
+ eos_token_id (`int`, *optional*): The id of the _end-of-stream_ token.
218
+ decoder_start_token_id (`int`, *optional*):
219
+ If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token.
220
+ sep_token_id (`int`, *optional*): The id of the _separation_ token.
221
+
222
+ > PyTorch specific parameters
223
+
224
+ torchscript (`bool`, *optional*, defaults to `False`):
225
+ Whether or not the model should be used with Torchscript.
226
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
227
+ Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
228
+ model has a output word embedding layer.
229
+ torch_dtype (`str`, *optional*):
230
+ The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
231
+ (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
232
+ model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
233
+ `float16` weights. Since the config object is stored in plain text, this attribute contains just the
234
+ floating type string without the `torch.` prefix. For example, for `torch.float16`, `torch_dtype` is the
235
+ `"float16"` string.
236
+
237
+ This attribute is currently not used during model loading, but this may change in future
238
+ versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
239
+
240
+ > TensorFlow specific parameters
241
+
242
+ use_bfloat16 (`bool`, *optional*, defaults to `False`):
243
+ Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
244
+ tf_legacy_loss (`bool`, *optional*, defaults to `False`):
245
+ Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
246
+ not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
247
+ v5.
248
+ """
249
+ model_type: str = ""
250
+ is_composition: bool = False
251
+ attribute_map: Dict[str, str] = {}
252
+ _auto_class: Optional[str] = None
253
+
254
+ def __setattr__(self, key, value):
255
+ if key in super().__getattribute__("attribute_map"):
256
+ key = super().__getattribute__("attribute_map")[key]
257
+ super().__setattr__(key, value)
258
+
259
+ def __getattribute__(self, key):
260
+ if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
261
+ key = super().__getattribute__("attribute_map")[key]
262
+ return super().__getattribute__(key)
263
+
264
+ def __init__(self, **kwargs):
265
+ # Attributes with defaults
266
+ self.return_dict = kwargs.pop("return_dict", True)
267
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
268
+ self.output_attentions = kwargs.pop("output_attentions", False)
269
+ self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
270
+ self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
271
+ self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
272
+ self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False) # Only used by TensorFlow models
273
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
274
+ self.tie_word_embeddings = kwargs.pop(
275
+ "tie_word_embeddings", True
276
+ ) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
277
+
278
+ # `is_decoder` is used in encoder-decoder models to differentiate the decoder from the encoder
279
+ self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
280
+ self.is_decoder = kwargs.pop("is_decoder", False)
281
+ self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
282
+ self.add_cross_attention = kwargs.pop("add_cross_attention", False)
283
+ self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
284
+
285
+ # Parameters for sequence generation
286
+ self.max_length = kwargs.pop("max_length", 20)
287
+ self.min_length = kwargs.pop("min_length", 0)
288
+ self.do_sample = kwargs.pop("do_sample", False)
289
+ self.early_stopping = kwargs.pop("early_stopping", False)
290
+ self.num_beams = kwargs.pop("num_beams", 1)
291
+ self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
292
+ self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
293
+ self.temperature = kwargs.pop("temperature", 1.0)
294
+ self.top_k = kwargs.pop("top_k", 50)
295
+ self.top_p = kwargs.pop("top_p", 1.0)
296
+ self.typical_p = kwargs.pop("typical_p", 1.0)
297
+ self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
298
+ self.length_penalty = kwargs.pop("length_penalty", 1.0)
299
+ self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
300
+ self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
301
+ self.bad_words_ids = kwargs.pop("bad_words_ids", None)
302
+ self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
303
+ self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
304
+ self.output_scores = kwargs.pop("output_scores", False)
305
+ self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
306
+ self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
307
+ self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
308
+ self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
309
+ self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
310
+ self.suppress_tokens = kwargs.pop("suppress_tokens", None)
311
+ self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
312
+
313
+ # Fine-tuning task arguments
314
+ self.architectures = kwargs.pop("architectures", None)
315
+ self.finetuning_task = kwargs.pop("finetuning_task", None)
316
+ self.id2label = kwargs.pop("id2label", None)
317
+ self.label2id = kwargs.pop("label2id", None)
318
+ if self.label2id is not None and not isinstance(self.label2id, dict):
319
+ raise ValueError("Argument label2id should be a dictionary.")
320
+ if self.id2label is not None:
321
+ if not isinstance(self.id2label, dict):
322
+ raise ValueError("Argument id2label should be a dictionary.")
323
+ num_labels = kwargs.pop("num_labels", None)
324
+ if num_labels is not None and len(self.id2label) != num_labels:
325
+ logger.warning(
326
+ f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
327
+ f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}."
328
+ )
329
+ self.id2label = {int(key): value for key, value in self.id2label.items()}
330
+ # Keys are always strings in JSON so convert ids to int here.
331
+ else:
332
+ self.num_labels = kwargs.pop("num_labels", 2)
333
+
334
+ if self.torch_dtype is not None and isinstance(self.torch_dtype, str):
335
+ # we will start using self.torch_dtype in v5, but to be consistent with
336
+ # from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
337
+ if is_torch_available():
338
+ import torch
339
+
340
+ self.torch_dtype = getattr(torch, self.torch_dtype)
341
+
342
+ # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
343
+ self.tokenizer_class = kwargs.pop("tokenizer_class", None)
344
+ self.prefix = kwargs.pop("prefix", None)
345
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
346
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
347
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
348
+ self.sep_token_id = kwargs.pop("sep_token_id", None)
349
+
350
+ self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
351
+
352
+ # task specific arguments
353
+ self.task_specific_params = kwargs.pop("task_specific_params", None)
354
+
355
+ # regression / multi-label classification
356
+ self.problem_type = kwargs.pop("problem_type", None)
357
+ allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
358
+ if self.problem_type is not None and self.problem_type not in allowed_problem_types:
359
+ raise ValueError(
360
+ f"The config parameter `problem_type` was not understood: received {self.problem_type} "
361
+ "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
362
+ )
363
+
364
+ # TPU arguments
365
+ if kwargs.pop("xla_device", None) is not None:
366
+ logger.warning(
367
+ "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
368
+ "safely remove it from your `config.json` file."
369
+ )
370
+
371
+ # Name or path to the pretrained checkpoint
372
+ self._name_or_path = str(kwargs.pop("name_or_path", ""))
373
+ # Config hash
374
+ self._commit_hash = kwargs.pop("_commit_hash", None)
375
+
376
+ # Drop the transformers version info
377
+ self.transformers_version = kwargs.pop("transformers_version", None)
378
+
379
+ # Deal with gradient checkpointing
380
+ if kwargs.get("gradient_checkpointing", False):
381
+ warnings.warn(
382
+ "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
383
+ "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
384
+ "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
385
+ )
386
+
387
+ # Additional attributes without default values
388
+ for key, value in kwargs.items():
389
+ try:
390
+ setattr(self, key, value)
391
+ except AttributeError as err:
392
+ logger.error(f"Can't set {key} with value {value} for {self}")
393
+ raise err
394
+
395
+ @property
396
+ def name_or_path(self) -> str:
397
+ return getattr(self, "_name_or_path", None)
398
+
399
+ @name_or_path.setter
400
+ def name_or_path(self, value):
401
+ self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
402
+
403
+ @property
404
+ def use_return_dict(self) -> bool:
405
+ """
406
+ `bool`: Whether or not return [`~utils.ModelOutput`] instead of tuples.
407
+ """
408
+ # If torchscript is set, force `return_dict=False` to avoid jit errors
409
+ return self.return_dict and not self.torchscript
410
+
411
+ @property
412
+ def num_labels(self) -> int:
413
+ """
414
+ `int`: The number of labels for classification models.
415
+ """
416
+ return len(self.id2label)
417
+
418
+ @num_labels.setter
419
+ def num_labels(self, num_labels: int):
420
+ if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
421
+ self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
422
+ self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
423
+
424
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
425
+ """
426
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
427
+ [`~PretrainedConfig.from_pretrained`] class method.
428
+
429
+ Args:
430
+ save_directory (`str` or `os.PathLike`):
431
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
432
+ push_to_hub (`bool`, *optional*, defaults to `False`):
433
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
434
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
435
+ namespace).
436
+ kwargs (`Dict[str, Any]`, *optional*):
437
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
438
+ """
439
+ self._set_token_in_kwargs(kwargs)
440
+
441
+ if os.path.isfile(save_directory):
442
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
443
+
444
+ os.makedirs(save_directory, exist_ok=True)
445
+
446
+ if push_to_hub:
447
+ commit_message = kwargs.pop("commit_message", None)
448
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
449
+ repo_id = self._create_repo(repo_id, **kwargs)
450
+ files_timestamps = self._get_files_timestamps(save_directory)
451
+
452
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
453
+ # loaded from the Hub.
454
+ if self._auto_class is not None:
455
+ custom_object_save(self, save_directory, config=self)
456
+
457
+ # If we save using the predefined names, we can load using `from_pretrained`
458
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
459
+
460
+ self.to_json_file(output_config_file, use_diff=True)
461
+ logger.info(f"Configuration saved in {output_config_file}")
462
+
463
+ if push_to_hub:
464
+ self._upload_modified_files(
465
+ save_directory,
466
+ repo_id,
467
+ files_timestamps,
468
+ commit_message=commit_message,
469
+ token=kwargs.get("token"),
470
+ )
471
+
472
+ @staticmethod
473
+ def _set_token_in_kwargs(kwargs, token=None):
474
+ """Temporary method to deal with `token` and `use_auth_token`.
475
+
476
+ This method is to avoid applying the same changes in all model config classes that overwrite `from_pretrained`.
477
+
478
+ Need to clean up `use_auth_token` in a follow-up PR.
479
+ """
480
+ # Some model config classes like CLIP define their own `from_pretrained` without the new argument `token` yet.
481
+ if token is None:
482
+ token = kwargs.pop("token", None)
483
+ use_auth_token = kwargs.pop("use_auth_token", None)
484
+
485
+ if use_auth_token is not None:
486
+ warnings.warn(
487
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
488
+ )
489
+ if token is not None:
490
+ raise ValueError(
491
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
492
+ )
493
+ token = use_auth_token
494
+
495
+ if token is not None:
496
+ kwargs["token"] = token
497
+
498
+ @classmethod
499
+ def from_pretrained(
500
+ cls,
501
+ pretrained_model_name_or_path: Union[str, os.PathLike],
502
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
503
+ force_download: bool = False,
504
+ local_files_only: bool = False,
505
+ token: Optional[Union[str, bool]] = None,
506
+ revision: str = "main",
507
+ **kwargs,
508
+ ) -> "PretrainedConfig":
509
+ r"""
510
+ Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.
511
+
512
+ Args:
513
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
514
+ This can be either:
515
+
516
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
517
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
518
+ namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
519
+ - a path to a *directory* containing a configuration file saved using the
520
+ [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
521
+ - a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
522
+ cache_dir (`str` or `os.PathLike`, *optional*):
523
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
524
+ standard cache should not be used.
525
+ force_download (`bool`, *optional*, defaults to `False`):
526
+ Whether or not to force to (re-)download the configuration files and override the cached versions if
527
+ they exist.
528
+ resume_download (`bool`, *optional*, defaults to `False`):
529
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file
530
+ exists.
531
+ proxies (`Dict[str, str]`, *optional*):
532
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
533
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
534
+ token (`str` or `bool`, *optional*):
535
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
536
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
537
+ revision (`str`, *optional*, defaults to `"main"`):
538
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
539
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
540
+ identifier allowed by git.
541
+
542
+ <Tip>
543
+
544
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
545
+
546
+ </Tip>
547
+
548
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
549
+ If `False`, then this function returns just the final configuration object.
550
+
551
+ If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
552
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
553
+ part of `kwargs` which has not been used to update `config` and is otherwise ignored.
554
+ subfolder (`str`, *optional*, defaults to `""`):
555
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
556
+ specify the folder name here.
557
+ kwargs (`Dict[str, Any]`, *optional*):
558
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
559
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
560
+ by the `return_unused_kwargs` keyword parameter.
561
+
562
+ Returns:
563
+ [`PretrainedConfig`]: The configuration object instantiated from this pretrained model.
564
+
565
+ Examples:
566
+
567
+ ```python
568
+ # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
569
+ # derived class: BertConfig
570
+ config = BertConfig.from_pretrained(
571
+ "bert-base-uncased"
572
+ ) # Download configuration from huggingface.co and cache.
573
+ config = BertConfig.from_pretrained(
574
+ "./test/saved_model/"
575
+ ) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
576
+ config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
577
+ config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
578
+ assert config.output_attentions == True
579
+ config, unused_kwargs = BertConfig.from_pretrained(
580
+ "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
581
+ )
582
+ assert config.output_attentions == True
583
+ assert unused_kwargs == {"foo": False}
584
+ ```"""
585
+ kwargs["cache_dir"] = cache_dir
586
+ kwargs["force_download"] = force_download
587
+ kwargs["local_files_only"] = local_files_only
588
+ kwargs["revision"] = revision
589
+
590
+ cls._set_token_in_kwargs(kwargs, token)
591
+
592
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
593
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
594
+ logger.warning(
595
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
596
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
597
+ )
598
+
599
+ return cls.from_dict(config_dict, **kwargs)
600
+
601
+ @classmethod
602
+ def get_config_dict(
603
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
604
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
605
+ """
606
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
607
+ [`PretrainedConfig`] using `from_dict`.
608
+
609
+ Parameters:
610
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
611
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
612
+
613
+ Returns:
614
+ `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
615
+
616
+ """
617
+ cls._set_token_in_kwargs(kwargs)
618
+
619
+ original_kwargs = copy.deepcopy(kwargs)
620
+ # Get config dict associated with the base config file
621
+ config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
622
+ if "_commit_hash" in config_dict:
623
+ original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
624
+
625
+ # That config file may point us toward another config file to use.
626
+ if "configuration_files" in config_dict:
627
+ configuration_file = get_configuration_file(config_dict["configuration_files"])
628
+ config_dict, kwargs = cls._get_config_dict(
629
+ pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
630
+ )
631
+
632
+ return config_dict, kwargs
633
+
634
+ @classmethod
635
+ def _get_config_dict(
636
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
637
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
638
+ cache_dir = kwargs.pop("cache_dir", None)
639
+ force_download = kwargs.pop("force_download", False)
640
+ resume_download = kwargs.pop("resume_download", False)
641
+ proxies = kwargs.pop("proxies", None)
642
+ token = kwargs.pop("token", None)
643
+ local_files_only = kwargs.pop("local_files_only", False)
644
+ revision = kwargs.pop("revision", None)
645
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
646
+ subfolder = kwargs.pop("subfolder", "")
647
+ from_pipeline = kwargs.pop("_from_pipeline", None)
648
+ from_auto_class = kwargs.pop("_from_auto", False)
649
+ commit_hash = kwargs.pop("_commit_hash", None)
650
+
651
+ if trust_remote_code is True:
652
+ logger.warning(
653
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
654
+ " ignored."
655
+ )
656
+
657
+ user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
658
+ if from_pipeline is not None:
659
+ user_agent["using_pipeline"] = from_pipeline
660
+
661
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
662
+
663
+ is_local = os.path.isdir(pretrained_model_name_or_path)
664
+ if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
665
+ # Special case when pretrained_model_name_or_path is a local file
666
+ resolved_config_file = pretrained_model_name_or_path
667
+ is_local = True
668
+ elif is_remote_url(pretrained_model_name_or_path):
669
+ configuration_file = pretrained_model_name_or_path
670
+ resolved_config_file = download_url(pretrained_model_name_or_path)
671
+ else:
672
+ configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
673
+
674
+ try:
675
+ # Load from local folder or from cache or download from model Hub and cache
676
+ resolved_config_file = cached_file(
677
+ pretrained_model_name_or_path,
678
+ configuration_file,
679
+ cache_dir=cache_dir,
680
+ force_download=force_download,
681
+ proxies=proxies,
682
+ resume_download=resume_download,
683
+ local_files_only=local_files_only,
684
+ token=token,
685
+ user_agent=user_agent,
686
+ revision=revision,
687
+ subfolder=subfolder,
688
+ _commit_hash=commit_hash,
689
+ )
690
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
691
+ except EnvironmentError:
692
+ # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
693
+ # the original exception.
694
+ raise
695
+ except Exception:
696
+ # For any other exception, we throw a generic error.
697
+ raise EnvironmentError(
698
+ f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
699
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
700
+ f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
701
+ f" containing a {configuration_file} file"
702
+ )
703
+
704
+ try:
705
+ # Load config dict
706
+ config_dict = cls._dict_from_json_file(resolved_config_file)
707
+ config_dict["_commit_hash"] = commit_hash
708
+ except (json.JSONDecodeError, UnicodeDecodeError):
709
+ raise EnvironmentError(
710
+ f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
711
+ )
712
+
713
+ if is_local:
714
+ logger.info(f"loading configuration file {resolved_config_file}")
715
+ else:
716
+ logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
717
+
718
+ if "auto_map" in config_dict and not is_local:
719
+ config_dict["auto_map"] = add_model_info_to_auto_map(
720
+ config_dict["auto_map"], pretrained_model_name_or_path
721
+ )
722
+ return config_dict, kwargs
723
+
724
+ @classmethod
725
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
726
+ """
727
+ Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
728
+
729
+ Args:
730
+ config_dict (`Dict[str, Any]`):
731
+ Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
732
+ retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
733
+ kwargs (`Dict[str, Any]`):
734
+ Additional parameters from which to initialize the configuration object.
735
+
736
+ Returns:
737
+ [`PretrainedConfig`]: The configuration object instantiated from those parameters.
738
+ """
739
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
740
+ # Those arguments may be passed along for our internal telemetry.
741
+ # We remove them so they don't appear in `return_unused_kwargs`.
742
+ kwargs.pop("_from_auto", None)
743
+ kwargs.pop("_from_pipeline", None)
744
+ # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
745
+ if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
746
+ kwargs["_commit_hash"] = config_dict["_commit_hash"]
747
+
748
+ config = cls(**config_dict)
749
+
750
+ if hasattr(config, "pruned_heads"):
751
+ config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
752
+
753
+ # Update config with kwargs if needed
754
+ if "num_labels" in kwargs and "id2label" in kwargs:
755
+ num_labels = kwargs["num_labels"]
756
+ id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
757
+ if len(id2label) != num_labels:
758
+ raise ValueError(
759
+ f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
760
+ f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
761
+ "one of them."
762
+ )
763
+ to_remove = []
764
+ for key, value in kwargs.items():
765
+ if hasattr(config, key):
766
+ current_attr = getattr(config, key)
767
+ # To authorize passing a custom subconfig as kwarg in models that have nested configs.
768
+ if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
769
+ value = current_attr.__class__(**value)
770
+ setattr(config, key, value)
771
+ if key != "torch_dtype":
772
+ to_remove.append(key)
773
+ for key in to_remove:
774
+ kwargs.pop(key, None)
775
+
776
+ logger.info(f"Model config {config}")
777
+ if return_unused_kwargs:
778
+ return config, kwargs
779
+ else:
780
+ return config
781
+
782
+ @classmethod
783
+ def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
784
+ """
785
+ Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.
786
+
787
+ Args:
788
+ json_file (`str` or `os.PathLike`):
789
+ Path to the JSON file containing the parameters.
790
+
791
+ Returns:
792
+ [`PretrainedConfig`]: The configuration object instantiated from that JSON file.
793
+
794
+ """
795
+ config_dict = cls._dict_from_json_file(json_file)
796
+ return cls(**config_dict)
797
+
798
+ @classmethod
799
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
800
+ with open(json_file, "r", encoding="utf-8") as reader:
801
+ text = reader.read()
802
+ return json.loads(text)
803
+
804
+ def __eq__(self, other):
805
+ return isinstance(other, PretrainedConfig) and (self.__dict__ == other.__dict__)
806
+
807
+ def __repr__(self):
808
+ return f"{self.__class__.__name__} {self.to_json_string()}"
809
+
810
+ def to_diff_dict(self) -> Dict[str, Any]:
811
+ """
812
+ Removes all attributes from config which correspond to the default config attributes for better readability and
813
+ serializes to a Python dictionary.
814
+
815
+ Returns:
816
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
817
+ """
818
+ config_dict = self.to_dict()
819
+
820
+ # get the default config dict
821
+ default_config_dict = PretrainedConfig().to_dict()
822
+
823
+ # get class specific config dict
824
+ class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
825
+
826
+ serializable_config_dict = {}
827
+
828
+ # only serialize values that differ from the default config
829
+ for key, value in config_dict.items():
830
+ if (
831
+ isinstance(getattr(self, key, None), PretrainedConfig)
832
+ and key in class_config_dict
833
+ and isinstance(class_config_dict[key], dict)
834
+ ):
835
+ # For nested configs we need to clean the diff recursively
836
+ diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
837
+ if "model_type" in value:
838
+ # Needs to be set even if it's not in the diff
839
+ diff["model_type"] = value["model_type"]
840
+ if len(diff) > 0:
841
+ serializable_config_dict[key] = diff
842
+ elif (
843
+ key not in default_config_dict
844
+ or key == "transformers_version"
845
+ or value != default_config_dict[key]
846
+ or (key in class_config_dict and value != class_config_dict[key])
847
+ ):
848
+ serializable_config_dict[key] = value
849
+
850
+ if hasattr(self, "quantization_config"):
851
+ serializable_config_dict["quantization_config"] = (
852
+ self.quantization_config.to_dict()
853
+ if not isinstance(self.quantization_config, dict)
854
+ else self.quantization_config
855
+ )
856
+
857
+ self.dict_torch_dtype_to_str(serializable_config_dict)
858
+
859
+ if "_flash_attn_2_enabled" in serializable_config_dict:
860
+ del serializable_config_dict["_flash_attn_2_enabled"]
861
+
862
+ return serializable_config_dict
863
+
864
+ def to_dict(self) -> Dict[str, Any]:
865
+ """
866
+ Serializes this instance to a Python dictionary.
867
+
868
+ Returns:
869
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
870
+ """
871
+ output = copy.deepcopy(self.__dict__)
872
+ if hasattr(self.__class__, "model_type"):
873
+ output["model_type"] = self.__class__.model_type
874
+ if "_auto_class" in output:
875
+ del output["_auto_class"]
876
+ if "_commit_hash" in output:
877
+ del output["_commit_hash"]
878
+ if "_flash_attn_2_enabled" in output:
879
+ del output["_flash_attn_2_enabled"]
880
+
881
+ # Transformers version when serializing the model
882
+ output["transformers_version"] = __version__
883
+
884
+ for key, value in output.items():
885
+ # Deal with nested configs like CLIP
886
+ if isinstance(value, PretrainedConfig):
887
+ value = value.to_dict()
888
+ del value["transformers_version"]
889
+
890
+ output[key] = value
891
+
892
+ if hasattr(self, "quantization_config"):
893
+ output["quantization_config"] = (
894
+ self.quantization_config.to_dict()
895
+ if not isinstance(self.quantization_config, dict)
896
+ else self.quantization_config
897
+ )
898
+
899
+ self.dict_torch_dtype_to_str(output)
900
+
901
+ return output
902
+
903
+ def to_json_string(self, use_diff: bool = True) -> str:
904
+ """
905
+ Serializes this instance to a JSON string.
906
+
907
+ Args:
908
+ use_diff (`bool`, *optional*, defaults to `True`):
909
+ If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
910
+ is serialized to JSON string.
911
+
912
+ Returns:
913
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
914
+ """
915
+ if use_diff is True:
916
+ config_dict = self.to_diff_dict()
917
+ else:
918
+ config_dict = self.to_dict()
919
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
920
+
921
+ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
922
+ """
923
+ Save this instance to a JSON file.
924
+
925
+ Args:
926
+ json_file_path (`str` or `os.PathLike`):
927
+ Path to the JSON file in which this configuration instance's parameters will be saved.
928
+ use_diff (`bool`, *optional*, defaults to `True`):
929
+ If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
930
+ is serialized to JSON file.
931
+ """
932
+ with open(json_file_path, "w", encoding="utf-8") as writer:
933
+ writer.write(self.to_json_string(use_diff=use_diff))
934
+
935
+ def update(self, config_dict: Dict[str, Any]):
936
+ """
937
+ Updates attributes of this class with attributes from `config_dict`.
938
+
939
+ Args:
940
+ config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
941
+ """
942
+ for key, value in config_dict.items():
943
+ setattr(self, key, value)
944
+
945
+ def update_from_string(self, update_str: str):
946
+ """
947
+ Updates attributes of this class with attributes from `update_str`.
948
+
949
+ The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
950
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
951
+
952
+ The keys to change have to already exist in the config object.
953
+
954
+ Args:
955
+ update_str (`str`): String with attributes that should be updated for this class.
956
+
957
+ """
958
+
959
+ d = dict(x.split("=") for x in update_str.split(","))
960
+ for k, v in d.items():
961
+ if not hasattr(self, k):
962
+ raise ValueError(f"key {k} isn't in the original config dict")
963
+
964
+ old_v = getattr(self, k)
965
+ if isinstance(old_v, bool):
966
+ if v.lower() in ["true", "1", "y", "yes"]:
967
+ v = True
968
+ elif v.lower() in ["false", "0", "n", "no"]:
969
+ v = False
970
+ else:
971
+ raise ValueError(f"can't derive true or false from {v} (key {k})")
972
+ elif isinstance(old_v, int):
973
+ v = int(v)
974
+ elif isinstance(old_v, float):
975
+ v = float(v)
976
+ elif not isinstance(old_v, str):
977
+ raise ValueError(
978
+ f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
979
+ )
980
+
981
+ setattr(self, k, v)
982
+
983
+ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
984
+ """
985
+ Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
986
+ converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
987
+ string, which can then be stored in the json format.
988
+ """
989
+ if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
990
+ d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
991
+ for value in d.values():
992
+ if isinstance(value, dict):
993
+ self.dict_torch_dtype_to_str(value)
994
+
995
+ @classmethod
996
+ def register_for_auto_class(cls, auto_class="AutoConfig"):
997
+ """
998
+ Register this class with a given auto class. This should only be used for custom configurations as the ones in
999
+ the library are already mapped with `AutoConfig`.
1000
+
1001
+ <Tip warning={true}>
1002
+
1003
+ This API is experimental and may have some slight breaking changes in the next releases.
1004
+
1005
+ </Tip>
1006
+
1007
+ Args:
1008
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
1009
+ The auto class to register this new configuration with.
1010
+ """
1011
+ if not isinstance(auto_class, str):
1012
+ auto_class = auto_class.__name__
1013
+
1014
+ import transformers.models.auto as auto_module
1015
+
1016
+ if not hasattr(auto_module, auto_class):
1017
+ raise ValueError(f"{auto_class} is not a valid auto class.")
1018
+
1019
+ cls._auto_class = auto_class
1020
+
1021
+
1022
+ def get_configuration_file(configuration_files: List[str]) -> str:
1023
+ """
1024
+ Get the configuration file to use for this version of transformers.
1025
+
1026
+ Args:
1027
+ configuration_files (`List[str]`): The list of available configuration files.
1028
+
1029
+ Returns:
1030
+ `str`: The configuration file to use.
1031
+ """
1032
+ configuration_files_map = {}
1033
+ for file_name in configuration_files:
1034
+ search = _re_configuration_file.search(file_name)
1035
+ if search is not None:
1036
+ v = search.groups()[0]
1037
+ configuration_files_map[v] = file_name
1038
+ available_versions = sorted(configuration_files_map.keys())
1039
+
1040
+ # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
1041
+ configuration_file = CONFIG_NAME
1042
+ transformers_version = version.parse(__version__)
1043
+ for v in available_versions:
1044
+ if version.parse(v) <= transformers_version:
1045
+ configuration_file = configuration_files_map[v]
1046
+ else:
1047
+ # No point going further since the versions are sorted.
1048
+ break
1049
+
1050
+ return configuration_file
1051
+
1052
+
1053
+ def recursive_diff_dict(dict_a, dict_b, config_obj=None):
1054
+ """
1055
+ Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
1056
+ values from `dict_a` that are different from values in `dict_b`.
1057
+ """
1058
+ diff = {}
1059
+ default = config_obj.__class__().to_dict() if config_obj is not None else {}
1060
+ for key, value in dict_a.items():
1061
+ obj_value = getattr(config_obj, str(key), None)
1062
+ if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
1063
+ diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
1064
+ if len(diff_value) > 0:
1065
+ diff[key] = diff_value
1066
+ elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]:
1067
+ diff[key] = value
1068
+ return diff
1069
+
1070
+
1071
+ PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
1072
+ if PretrainedConfig.push_to_hub.__doc__ is not None:
1073
+ PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
1074
+ object="config", object_class="AutoConfig", object_files="configuration file"
1075
+ )
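As a usage note for the `register_for_auto_class` classmethod above, the following is a minimal sketch of registering a custom configuration so that `AutoConfig` can resolve it; `MyCustomConfig` and its `model_type` are hypothetical names used only for illustration:

from transformers import PretrainedConfig

class MyCustomConfig(PretrainedConfig):
    model_type = "my-custom-model"  # hypothetical model type

# Only needed for configurations defined outside the library.
MyCustomConfig.register_for_auto_class("AutoConfig")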
transformers_4_35_0/convert_graph_to_onnx.py ADDED
@@ -0,0 +1,569 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from argparse import ArgumentParser
17
+ from os import listdir, makedirs
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional, Tuple
20
+
21
+ from packaging.version import Version, parse
22
+
23
+ from transformers.pipelines import Pipeline, pipeline
24
+ from transformers.tokenization_utils import BatchEncoding
25
+ from transformers.utils import ModelOutput, is_tf_available, is_torch_available
26
+
27
+
28
+ # This is the minimal required version to
29
+ # support some ONNX Runtime features
30
+ ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")
31
+
32
+
33
+ SUPPORTED_PIPELINES = [
34
+ "feature-extraction",
35
+ "ner",
36
+ "sentiment-analysis",
37
+ "fill-mask",
38
+ "question-answering",
39
+ "text-generation",
40
+ "translation_en_to_fr",
41
+ "translation_en_to_de",
42
+ "translation_en_to_ro",
43
+ ]
44
+
45
+
46
+ class OnnxConverterArgumentParser(ArgumentParser):
47
+ """
48
+ Wraps all the script arguments supported to export transformers models to ONNX IR
49
+ """
50
+
51
+ def __init__(self):
52
+ super().__init__("ONNX Converter")
53
+
54
+ self.add_argument(
55
+ "--pipeline",
56
+ type=str,
57
+ choices=SUPPORTED_PIPELINES,
58
+ default="feature-extraction",
59
+ )
60
+ self.add_argument(
61
+ "--model",
62
+ type=str,
63
+ required=True,
64
+ help="Model's id or path (ex: bert-base-cased)",
65
+ )
66
+ self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: bert-base-cased)")
67
+ self.add_argument(
68
+ "--framework",
69
+ type=str,
70
+ choices=["pt", "tf"],
71
+ help="Framework for loading the model",
72
+ )
73
+ self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
74
+ self.add_argument(
75
+ "--check-loading",
76
+ action="store_true",
77
+ help="Check ONNX is able to load the model",
78
+ )
79
+ self.add_argument(
80
+ "--use-external-format",
81
+ action="store_true",
82
+ help="Allow exporting models larger than 2GB",
83
+ )
84
+ self.add_argument(
85
+ "--quantize",
86
+ action="store_true",
87
+ help="Quantize the neural network to be run with int8",
88
+ )
89
+ self.add_argument("output")
90
+
91
+
92
+ def generate_identified_filename(filename: Path, identifier: str) -> Path:
93
+ """
94
+ Append a string-identifier at the end (before the extension, if any) to the provided filepath
95
+
96
+ Args:
97
+ filename: pathlib.Path The path object to which we would like to add an identifier suffix
98
+ identifier: The suffix to add
99
+
100
+ Returns: Path with the identifier concatenated at the end of the filename (before the extension)
101
+ """
102
+ return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix)
103
+
104
+
105
+ def check_onnxruntime_requirements(minimum_version: Version):
106
+ """
107
+ Check that onnxruntime is installed and that the installed version is recent enough
108
+
109
+ Raises:
110
+ ImportError: If onnxruntime is not installed or too old version is found
111
+ """
112
+ try:
113
+ import onnxruntime
114
+
115
+ # Parse the version of the installed onnxruntime
116
+ ort_version = parse(onnxruntime.__version__)
117
+
118
+ # We require 1.4.0 minimum
119
+ if ort_version < ORT_QUANTIZE_MINIMUM_VERSION:
120
+ raise ImportError(
121
+ f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
122
+ f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
123
+ "Please update onnxruntime by running `pip install --upgrade onnxruntime`"
124
+ )
125
+
126
+ except ImportError:
127
+ raise ImportError(
128
+ "onnxruntime doesn't seem to be currently installed. "
129
+ "Please install the onnxruntime by running `pip install onnxruntime`"
130
+ " and relaunch the conversion."
131
+ )
132
+
133
+
134
+ def ensure_valid_input(model, tokens, input_names):
135
+ """
136
+ Ensure inputs are presented in the correct order, without any None values
137
+
138
+ Args:
139
+ model: The model used to forward the input data
140
+ tokens: BatchEncoding holding the input data
141
+ input_names: The name of the inputs
142
+
143
+ Returns: Tuple of (ordered input names, ordered model args)
144
+
145
+ """
146
+ print("Ensuring inputs are in correct order")
147
+
148
+ model_args_name = model.forward.__code__.co_varnames
149
+ model_args, ordered_input_names = [], []
150
+ for arg_name in model_args_name[1:]: # start at index 1 to skip "self" argument
151
+ if arg_name in input_names:
152
+ ordered_input_names.append(arg_name)
153
+ model_args.append(tokens[arg_name])
154
+ else:
155
+ print(f"{arg_name} is not present in the generated input list.")
156
+ break
157
+
158
+ print(f"Generated inputs order: {ordered_input_names}")
159
+ return ordered_input_names, tuple(model_args)
160
+
161
+
162
+ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
163
+ """
164
+ Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model
165
+
166
+ Args:
167
+ nlp: The pipeline object holding the model to be exported
168
+ framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)
169
+
170
+ Returns:
171
+
172
+ - List of the inferred input variable names
173
+ - List of the inferred output variable names
174
+ - Dictionary with input/output variables names as key and shape tensor as value
175
+ - a BatchEncoding reference which was used to infer all the above information
176
+ """
177
+
178
+ def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
179
+ if isinstance(tensor, (tuple, list)):
180
+ return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]
181
+
182
+ else:
183
+ # Let's assume batch is the first axis with only 1 element (which might not always be true ...)
184
+ axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
185
+ if is_input:
186
+ if len(tensor.shape) == 2:
187
+ axes[1] = "sequence"
188
+ else:
189
+ raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})")
190
+ else:
191
+ seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
192
+ axes.update({dim: "sequence" for dim in seq_axes})
193
+
194
+ print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}")
195
+ return axes
196
+
197
+ tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
198
+ seq_len = tokens.input_ids.shape[-1]
199
+ outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)
200
+ if isinstance(outputs, ModelOutput):
201
+ outputs = outputs.to_tuple()
202
+ if not isinstance(outputs, (list, tuple)):
203
+ outputs = (outputs,)
204
+
205
+ # Generate input names & axes
206
+ input_vars = list(tokens.keys())
207
+ input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}
208
+
209
+ # flatten potentially grouped outputs (past for gpt2, attentions)
210
+ outputs_flat = []
211
+ for output in outputs:
212
+ if isinstance(output, (tuple, list)):
213
+ outputs_flat.extend(output)
214
+ else:
215
+ outputs_flat.append(output)
216
+
217
+ # Generate output names & axes
218
+ output_names = [f"output_{i}" for i in range(len(outputs_flat))]
219
+ output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}
220
+
221
+ # Create the aggregated axes representation
222
+ dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
223
+ return input_vars, output_names, dynamic_axes, tokens
224
+
225
+
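A minimal sketch of calling `infer_shapes` on a PyTorch-backed pipeline, assuming the module is importable as `transformers.convert_graph_to_onnx` and using `distilbert-base-uncased` as an arbitrary illustrative checkpoint:

from transformers import pipeline
from transformers.convert_graph_to_onnx import infer_shapes

nlp = pipeline("feature-extraction", model="distilbert-base-uncased", framework="pt")
# Returns the input names, generated output names, dynamic axes and the reference BatchEncoding.
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")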
226
+ def load_graph_from_args(
227
+ pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs
228
+ ) -> Pipeline:
229
+ """
230
+ Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model)
231
+
232
+ Args:
233
+ pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)
234
+ framework: The framework to load the pipeline's model with ("pt" or "tf")
235
+ model: The model name which will be loaded by the pipeline
236
+ tokenizer: The tokenizer name which will be loaded by the pipeline, defaults to the model's value
237
+
238
+ Returns: Pipeline object
239
+
240
+ """
241
+ # If no tokenizer provided
242
+ if tokenizer is None:
243
+ tokenizer = model
244
+
245
+ # Check the wanted framework is available
246
+ if framework == "pt" and not is_torch_available():
247
+ raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
248
+ if framework == "tf" and not is_tf_available():
249
+ raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
250
+
251
+ print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})")
252
+
253
+ # Allocate tokenizer and model
254
+ return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)
255
+
256
+
257
+ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):
258
+ """
259
+ Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR)
260
+
261
+ Args:
262
+ nlp: The pipeline to be exported
263
+ opset: The actual version of the ONNX operator set to use
264
+ output: Path where the generated ONNX model will be stored
265
+ use_external_format: Split the model definition from its parameters to allow models bigger than 2GB
266
+
267
+ Returns:
268
+
269
+ """
270
+ if not is_torch_available():
271
+ raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
272
+
273
+ import torch
274
+ from torch.onnx import export
275
+
276
+ from transformers.pytorch_utils import is_torch_less_than_1_11
277
+
278
+ print(f"Using framework PyTorch: {torch.__version__}")
279
+
280
+ with torch.no_grad():
281
+ input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
282
+ ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
283
+
284
+ # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
285
+ # so we check the torch version for backwards compatibility
286
+ if is_torch_less_than_1_11:
287
+ export(
288
+ nlp.model,
289
+ model_args,
290
+ f=output.as_posix(),
291
+ input_names=ordered_input_names,
292
+ output_names=output_names,
293
+ dynamic_axes=dynamic_axes,
294
+ do_constant_folding=True,
295
+ use_external_data_format=use_external_format,
296
+ enable_onnx_checker=True,
297
+ opset_version=opset,
298
+ )
299
+ else:
300
+ export(
301
+ nlp.model,
302
+ model_args,
303
+ f=output.as_posix(),
304
+ input_names=ordered_input_names,
305
+ output_names=output_names,
306
+ dynamic_axes=dynamic_axes,
307
+ do_constant_folding=True,
308
+ opset_version=opset,
309
+ )
310
+
311
+
312
+ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
313
+ """
314
+ Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)
315
+
316
+ Args:
317
+ nlp: The pipeline to be exported
318
+ opset: The actual version of the ONNX operator set to use
319
+ output: Path where the generated ONNX model will be stored
320
+
321
+ Notes: TensorFlow cannot export models bigger than 2GB due to an internal TensorFlow constraint
322
+
323
+ """
324
+ if not is_tf_available():
325
+ raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
326
+
327
+ print("/!\\ Please note TensorFlow doesn't support exporting models > 2GB /!\\")
328
+
329
+ try:
330
+ import tensorflow as tf
331
+ import tf2onnx
332
+ from tf2onnx import __version__ as t2ov
333
+
334
+ print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")
335
+
336
+ # Build
337
+ input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
338
+
339
+ # Forward
340
+ nlp.model.predict(tokens.data)
341
+ input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in tokens.items()]
342
+ model_proto, _ = tf2onnx.convert.from_keras(
343
+ nlp.model, input_signature, opset=opset, output_path=output.as_posix()
344
+ )
345
+
346
+ except ImportError as e:
347
+ raise Exception(
348
+ f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first. {e}"
349
+ )
350
+
351
+
352
+ def convert(
353
+ framework: str,
354
+ model: str,
355
+ output: Path,
356
+ opset: int,
357
+ tokenizer: Optional[str] = None,
358
+ use_external_format: bool = False,
359
+ pipeline_name: str = "feature-extraction",
360
+ **model_kwargs,
361
+ ):
362
+ """
363
+ Convert the pipeline object to the ONNX Intermediate Representation (IR) format
364
+
365
+ Args:
366
+ framework: The framework the pipeline is backed by ("pt" or "tf")
367
+ model: The name of the model to load for the pipeline
368
+ output: The path where the ONNX graph will be stored
369
+ opset: The actual version of the ONNX operator set to use
370
+ tokenizer: The name of the tokenizer to load for the pipeline, defaults to the model's name if not provided
371
+ use_external_format:
372
+ Split the model definition from its parameters to allow models bigger than 2GB (PyTorch only)
373
+ pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)
374
+ model_kwargs: Keyword arguments to be forwarded to the model constructor
375
+
376
+ Returns:
377
+
378
+ """
379
+ warnings.warn(
380
+ "The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of"
381
+ " Transformers",
382
+ FutureWarning,
383
+ )
384
+ print(f"ONNX opset version set to: {opset}")
385
+
386
+ # Load the pipeline
387
+ nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)
388
+
389
+ if not output.parent.exists():
390
+ print(f"Creating folder {output.parent}")
391
+ makedirs(output.parent.as_posix())
392
+ elif len(listdir(output.parent.as_posix())) > 0:
393
+ raise Exception(f"Folder {output.parent.as_posix()} is not empty, aborting conversion")
394
+
395
+ # Export the graph
396
+ if framework == "pt":
397
+ convert_pytorch(nlp, opset, output, use_external_format)
398
+ else:
399
+ convert_tensorflow(nlp, opset, output)
400
+
401
+
402
+ def optimize(onnx_model_path: Path) -> Path:
403
+ """
404
+ Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the
405
+ optimizations possible
406
+
407
+ Args:
408
+ onnx_model_path: filepath where the model binary description is stored
409
+
410
+ Returns: Path where the optimized model binary description has been saved
411
+
412
+ """
413
+ from onnxruntime import InferenceSession, SessionOptions
414
+
415
+ # Generate model name with suffix "optimized"
416
+ opt_model_path = generate_identified_filename(onnx_model_path, "-optimized")
417
+ sess_option = SessionOptions()
418
+ sess_option.optimized_model_filepath = opt_model_path.as_posix()
419
+ _ = InferenceSession(onnx_model_path.as_posix(), sess_option)
420
+
421
+ print(f"Optimized model has been written at {opt_model_path}: \N{heavy check mark}")
422
+ print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\")
423
+
424
+ return opt_model_path
425
+
426
+
427
+ def quantize(onnx_model_path: Path) -> Path:
428
+ """
429
+ Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPUs
430
+
431
+ Args:
432
+ onnx_model_path: Path to the location where the exported ONNX model is stored
433
+
434
+ Returns: The Path generated for the quantized model
435
+ """
436
+ import onnx
437
+ import onnxruntime
438
+ from onnx.onnx_pb import ModelProto
439
+ from onnxruntime.quantization import QuantizationMode
440
+ from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
441
+ from onnxruntime.quantization.registry import IntegerOpsRegistry
442
+
443
+ # Load the ONNX model
444
+ onnx_model = onnx.load(onnx_model_path.as_posix())
445
+
446
+ if parse(onnx.__version__) < parse("1.5.0"):
447
+ print(
448
+ "Models larger than 2GB will fail to quantize due to protobuf constraint.\n"
449
+ "Please upgrade to onnxruntime >= 1.5.0."
450
+ )
451
+
452
+ # Copy it
453
+ copy_model = ModelProto()
454
+ copy_model.CopyFrom(onnx_model)
455
+
456
+ # Construct quantizer
457
+ # onnxruntime renamed input_qType to activation_qType in v1.13.1, so we
458
+ # check the onnxruntime version to ensure backward compatibility.
459
+ # See also: https://github.com/microsoft/onnxruntime/pull/12873
460
+ if parse(onnxruntime.__version__) < parse("1.13.1"):
461
+ quantizer = ONNXQuantizer(
462
+ model=copy_model,
463
+ per_channel=False,
464
+ reduce_range=False,
465
+ mode=QuantizationMode.IntegerOps,
466
+ static=False,
467
+ weight_qType=True,
468
+ input_qType=False,
469
+ tensors_range=None,
470
+ nodes_to_quantize=None,
471
+ nodes_to_exclude=None,
472
+ op_types_to_quantize=list(IntegerOpsRegistry),
473
+ )
474
+ else:
475
+ quantizer = ONNXQuantizer(
476
+ model=copy_model,
477
+ per_channel=False,
478
+ reduce_range=False,
479
+ mode=QuantizationMode.IntegerOps,
480
+ static=False,
481
+ weight_qType=True,
482
+ activation_qType=False,
483
+ tensors_range=None,
484
+ nodes_to_quantize=None,
485
+ nodes_to_exclude=None,
486
+ op_types_to_quantize=list(IntegerOpsRegistry),
487
+ )
488
+
489
+ # Quantize and export
490
+ quantizer.quantize_model()
491
+
492
+ # Append "-quantized" at the end of the model's name
493
+ quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized")
494
+
495
+ # Save model
496
+ print(f"Quantized model has been written at {quantized_model_path}: \N{heavy check mark}")
497
+ onnx.save_model(quantizer.model.model, quantized_model_path.as_posix())
498
+
499
+ return quantized_model_path
500
+
501
+
502
+ def verify(path: Path):
503
+ from onnxruntime import InferenceSession, SessionOptions
504
+ from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException
505
+
506
+ print(f"Checking ONNX model loading from: {path} ...")
507
+ try:
508
+ onnx_options = SessionOptions()
509
+ _ = InferenceSession(path.as_posix(), onnx_options, providers=["CPUExecutionProvider"])
510
+ print(f"Model {path} correctly loaded: \N{heavy check mark}")
511
+ except RuntimeException as re:
512
+ print(f"Error while loading the model {re}: \N{heavy ballot x}")
513
+
514
+
515
+ if __name__ == "__main__":
516
+ parser = OnnxConverterArgumentParser()
517
+ args = parser.parse_args()
518
+
519
+ # Make sure output is absolute path
520
+ args.output = Path(args.output).absolute()
521
+
522
+ try:
523
+ print("\n====== Converting model to ONNX ======")
524
+ # Convert
525
+ convert(
526
+ args.framework,
527
+ args.model,
528
+ args.output,
529
+ args.opset,
530
+ args.tokenizer,
531
+ args.use_external_format,
532
+ args.pipeline,
533
+ )
534
+
535
+ if args.quantize:
536
+ # Ensure requirements for quantization with onnxruntime are met
537
+ check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)
538
+
539
+ # onnxruntime optimizations don't provide the same level of performance on TensorFlow as on PyTorch
540
+ if args.framework == "tf":
541
+ print(
542
+ "\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n"
543
+ "\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n"
544
+ "\t For more information, please refer to the onnxruntime documentation:\n"
545
+ "\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n"
546
+ )
547
+
548
+ print("\n====== Optimizing ONNX model ======")
549
+
550
+ # Quantization works best when using the optimized version of the model
551
+ args.optimized_output = optimize(args.output)
552
+
553
+ # Do the quantization on the right graph
554
+ args.quantized_output = quantize(args.optimized_output)
555
+
556
+ # And verify
557
+ if args.check_loading:
558
+ print("\n====== Check exported ONNX model(s) ======")
559
+ verify(args.output)
560
+
561
+ if hasattr(args, "optimized_output"):
562
+ verify(args.optimized_output)
563
+
564
+ if hasattr(args, "quantized_output"):
565
+ verify(args.quantized_output)
566
+
567
+ except Exception as e:
568
+ print(f"Error while converting the model: {e}")
569
+ exit(1)
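As a usage sketch for the `convert` entry point defined above (the checkpoint name and output path are illustrative, and the output folder must be empty or not yet exist):

from pathlib import Path
from transformers.convert_graph_to_onnx import convert

convert(
    framework="pt",
    model="bert-base-cased",
    output=Path("onnx/bert-base-cased.onnx"),
    opset=11,
    pipeline_name="feature-extraction",
)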
transformers_4_35_0/convert_pytorch_checkpoint_to_tf2.py ADDED
@@ -0,0 +1,492 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Convert pytorch checkpoints to TensorFlow"""
16
+
17
+
18
+ import argparse
19
+ import os
20
+
21
+ from . import (
22
+ ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
23
+ BART_PRETRAINED_MODEL_ARCHIVE_LIST,
24
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
25
+ CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
26
+ CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
27
+ DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
28
+ DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
29
+ DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
30
+ DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
31
+ ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
32
+ FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
33
+ GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
34
+ LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
35
+ LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
36
+ OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
37
+ ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
38
+ T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
39
+ TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
40
+ WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
41
+ XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
42
+ XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
43
+ XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
44
+ AlbertConfig,
45
+ BartConfig,
46
+ BertConfig,
47
+ CamembertConfig,
48
+ CTRLConfig,
49
+ DistilBertConfig,
50
+ DPRConfig,
51
+ ElectraConfig,
52
+ FlaubertConfig,
53
+ GPT2Config,
54
+ LayoutLMConfig,
55
+ LxmertConfig,
56
+ OpenAIGPTConfig,
57
+ RobertaConfig,
58
+ T5Config,
59
+ TFAlbertForPreTraining,
60
+ TFBartForConditionalGeneration,
61
+ TFBartForSequenceClassification,
62
+ TFBertForPreTraining,
63
+ TFBertForQuestionAnswering,
64
+ TFBertForSequenceClassification,
65
+ TFCamembertForMaskedLM,
66
+ TFCTRLLMHeadModel,
67
+ TFDistilBertForMaskedLM,
68
+ TFDistilBertForQuestionAnswering,
69
+ TFDPRContextEncoder,
70
+ TFDPRQuestionEncoder,
71
+ TFDPRReader,
72
+ TFElectraForPreTraining,
73
+ TFFlaubertWithLMHeadModel,
74
+ TFGPT2LMHeadModel,
75
+ TFLayoutLMForMaskedLM,
76
+ TFLxmertForPreTraining,
77
+ TFLxmertVisualFeatureEncoder,
78
+ TFOpenAIGPTLMHeadModel,
79
+ TFRobertaForCausalLM,
80
+ TFRobertaForMaskedLM,
81
+ TFRobertaForSequenceClassification,
82
+ TFT5ForConditionalGeneration,
83
+ TFTransfoXLLMHeadModel,
84
+ TFWav2Vec2Model,
85
+ TFXLMRobertaForMaskedLM,
86
+ TFXLMWithLMHeadModel,
87
+ TFXLNetLMHeadModel,
88
+ TransfoXLConfig,
89
+ Wav2Vec2Config,
90
+ Wav2Vec2Model,
91
+ XLMConfig,
92
+ XLMRobertaConfig,
93
+ XLNetConfig,
94
+ is_torch_available,
95
+ load_pytorch_checkpoint_in_tf2_model,
96
+ )
97
+ from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
98
+
99
+
100
+ if is_torch_available():
101
+ import numpy as np
102
+ import torch
103
+
104
+ from . import (
105
+ AlbertForPreTraining,
106
+ BartForConditionalGeneration,
107
+ BertForPreTraining,
108
+ BertForQuestionAnswering,
109
+ BertForSequenceClassification,
110
+ CamembertForMaskedLM,
111
+ CTRLLMHeadModel,
112
+ DistilBertForMaskedLM,
113
+ DistilBertForQuestionAnswering,
114
+ DPRContextEncoder,
115
+ DPRQuestionEncoder,
116
+ DPRReader,
117
+ ElectraForPreTraining,
118
+ FlaubertWithLMHeadModel,
119
+ GPT2LMHeadModel,
120
+ LayoutLMForMaskedLM,
121
+ LxmertForPreTraining,
122
+ LxmertVisualFeatureEncoder,
123
+ OpenAIGPTLMHeadModel,
124
+ RobertaForMaskedLM,
125
+ RobertaForSequenceClassification,
126
+ T5ForConditionalGeneration,
127
+ TransfoXLLMHeadModel,
128
+ XLMRobertaForMaskedLM,
129
+ XLMWithLMHeadModel,
130
+ XLNetLMHeadModel,
131
+ )
132
+
133
+
134
+ logging.set_verbosity_info()
135
+
136
+ MODEL_CLASSES = {
137
+ "bart": (
138
+ BartConfig,
139
+ TFBartForConditionalGeneration,
140
+ TFBartForSequenceClassification,
141
+ BartForConditionalGeneration,
142
+ BART_PRETRAINED_MODEL_ARCHIVE_LIST,
143
+ ),
144
+ "bert": (
145
+ BertConfig,
146
+ TFBertForPreTraining,
147
+ BertForPreTraining,
148
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
149
+ ),
150
+ "bert-large-uncased-whole-word-masking-finetuned-squad": (
151
+ BertConfig,
152
+ TFBertForQuestionAnswering,
153
+ BertForQuestionAnswering,
154
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
155
+ ),
156
+ "bert-large-cased-whole-word-masking-finetuned-squad": (
157
+ BertConfig,
158
+ TFBertForQuestionAnswering,
159
+ BertForQuestionAnswering,
160
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
161
+ ),
162
+ "bert-base-cased-finetuned-mrpc": (
163
+ BertConfig,
164
+ TFBertForSequenceClassification,
165
+ BertForSequenceClassification,
166
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
167
+ ),
168
+ "dpr": (
169
+ DPRConfig,
170
+ TFDPRQuestionEncoder,
171
+ TFDPRContextEncoder,
172
+ TFDPRReader,
173
+ DPRQuestionEncoder,
174
+ DPRContextEncoder,
175
+ DPRReader,
176
+ DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
177
+ DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
178
+ DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
179
+ ),
180
+ "gpt2": (
181
+ GPT2Config,
182
+ TFGPT2LMHeadModel,
183
+ GPT2LMHeadModel,
184
+ GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
185
+ ),
186
+ "xlnet": (
187
+ XLNetConfig,
188
+ TFXLNetLMHeadModel,
189
+ XLNetLMHeadModel,
190
+ XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
191
+ ),
192
+ "xlm": (
193
+ XLMConfig,
194
+ TFXLMWithLMHeadModel,
195
+ XLMWithLMHeadModel,
196
+ XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
197
+ ),
198
+ "xlm-roberta": (
199
+ XLMRobertaConfig,
200
+ TFXLMRobertaForMaskedLM,
201
+ XLMRobertaForMaskedLM,
202
+ XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
203
+ ),
204
+ "transfo-xl": (
205
+ TransfoXLConfig,
206
+ TFTransfoXLLMHeadModel,
207
+ TransfoXLLMHeadModel,
208
+ TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
209
+ ),
210
+ "openai-gpt": (
211
+ OpenAIGPTConfig,
212
+ TFOpenAIGPTLMHeadModel,
213
+ OpenAIGPTLMHeadModel,
214
+ OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
215
+ ),
216
+ "roberta": (
217
+ RobertaConfig,
218
+ TFRobertaForCausalLM,
219
+ TFRobertaForMaskedLM,
220
+ RobertaForMaskedLM,
221
+ ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
222
+ ),
223
+ "layoutlm": (
224
+ LayoutLMConfig,
225
+ TFLayoutLMForMaskedLM,
226
+ LayoutLMForMaskedLM,
227
+ LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
228
+ ),
229
+ "roberta-large-mnli": (
230
+ RobertaConfig,
231
+ TFRobertaForSequenceClassification,
232
+ RobertaForSequenceClassification,
233
+ ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
234
+ ),
235
+ "camembert": (
236
+ CamembertConfig,
237
+ TFCamembertForMaskedLM,
238
+ CamembertForMaskedLM,
239
+ CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
240
+ ),
241
+ "flaubert": (
242
+ FlaubertConfig,
243
+ TFFlaubertWithLMHeadModel,
244
+ FlaubertWithLMHeadModel,
245
+ FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
246
+ ),
247
+ "distilbert": (
248
+ DistilBertConfig,
249
+ TFDistilBertForMaskedLM,
250
+ DistilBertForMaskedLM,
251
+ DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
252
+ ),
253
+ "distilbert-base-distilled-squad": (
254
+ DistilBertConfig,
255
+ TFDistilBertForQuestionAnswering,
256
+ DistilBertForQuestionAnswering,
257
+ DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
258
+ ),
259
+ "lxmert": (
260
+ LxmertConfig,
261
+ TFLxmertForPreTraining,
262
+ LxmertForPreTraining,
263
+ LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
264
+ ),
265
+ "lxmert-visual-feature-encoder": (
266
+ LxmertConfig,
267
+ TFLxmertVisualFeatureEncoder,
268
+ LxmertVisualFeatureEncoder,
269
+ LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
270
+ ),
271
+ "ctrl": (
272
+ CTRLConfig,
273
+ TFCTRLLMHeadModel,
274
+ CTRLLMHeadModel,
275
+ CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
276
+ ),
277
+ "albert": (
278
+ AlbertConfig,
279
+ TFAlbertForPreTraining,
280
+ AlbertForPreTraining,
281
+ ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
282
+ ),
283
+ "t5": (
284
+ T5Config,
285
+ TFT5ForConditionalGeneration,
286
+ T5ForConditionalGeneration,
287
+ T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
288
+ ),
289
+ "electra": (
290
+ ElectraConfig,
291
+ TFElectraForPreTraining,
292
+ ElectraForPreTraining,
293
+ ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
294
+ ),
295
+ "wav2vec2": (
296
+ Wav2Vec2Config,
297
+ TFWav2Vec2Model,
298
+ Wav2Vec2Model,
299
+ WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
300
+ ),
301
+ }
302
+
303
+
304
+ def convert_pt_checkpoint_to_tf(
305
+ model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
306
+ ):
307
+ if model_type not in MODEL_CLASSES:
308
+ raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.")
309
+
310
+ config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]
311
+
312
+ # Initialise TF model
313
+ if config_file in aws_config_map:
314
+ config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
315
+ config = config_class.from_json_file(config_file)
316
+ config.output_hidden_states = True
317
+ config.output_attentions = True
318
+ print(f"Building TensorFlow model from configuration: {config}")
319
+ tf_model = model_class(config)
320
+
321
+ # Load weights from tf checkpoint
322
+ if pytorch_checkpoint_path in aws_config_map.keys():
323
+ pytorch_checkpoint_path = cached_file(
324
+ pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
325
+ )
326
+ # Load PyTorch checkpoint in tf2 model:
327
+ tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
328
+
329
+ if compare_with_pt_model:
330
+ tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
331
+
332
+ state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
333
+ pt_model = pt_model_class.from_pretrained(
334
+ pretrained_model_name_or_path=None, config=config, state_dict=state_dict
335
+ )
336
+
337
+ with torch.no_grad():
338
+ pto = pt_model(**pt_model.dummy_inputs)
339
+
340
+ np_pt = pto[0].numpy()
341
+ np_tf = tfo[0].numpy()
342
+ diff = np.amax(np.abs(np_pt - np_tf))
343
+ print(f"Max absolute difference between models outputs {diff}")
344
+ assert diff <= 2e-2, f"Error, model absolute difference is >2e-2: {diff}"
345
+
346
+ # Save pytorch-model
347
+ print(f"Save TensorFlow model to {tf_dump_path}")
348
+ tf_model.save_weights(tf_dump_path, save_format="h5")
349
+
350
+
351
+ def convert_all_pt_checkpoints_to_tf(
352
+ args_model_type,
353
+ tf_dump_path,
354
+ model_shortcut_names_or_path=None,
355
+ config_shortcut_names_or_path=None,
356
+ compare_with_pt_model=False,
357
+ use_cached_models=False,
358
+ remove_cached_files=False,
359
+ only_convert_finetuned_models=False,
360
+ ):
361
+ if args_model_type is None:
362
+ model_types = list(MODEL_CLASSES.keys())
363
+ else:
364
+ model_types = [args_model_type]
365
+
366
+ for j, model_type in enumerate(model_types, start=1):
367
+ print("=" * 100)
368
+ print(f" Converting model type {j}/{len(model_types)}: {model_type}")
369
+ print("=" * 100)
370
+ if model_type not in MODEL_CLASSES:
371
+ raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.")
372
+
373
+ config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
374
+
375
+ if model_shortcut_names_or_path is None:
376
+ model_shortcut_names_or_path = list(aws_model_maps.keys())
377
+ if config_shortcut_names_or_path is None:
378
+ config_shortcut_names_or_path = model_shortcut_names_or_path
379
+
380
+ for i, (model_shortcut_name, config_shortcut_name) in enumerate(
381
+ zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
382
+ ):
383
+ print("-" * 100)
384
+ if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
385
+ if not only_convert_finetuned_models:
386
+ print(f" Skipping finetuned checkpoint {model_shortcut_name}")
387
+ continue
388
+ model_type = model_shortcut_name
389
+ elif only_convert_finetuned_models:
390
+ print(f" Skipping not finetuned checkpoint {model_shortcut_name}")
391
+ continue
392
+ print(
393
+ f" Converting checkpoint {i}/{len(aws_config_map)}: {model_shortcut_name} - model_type {model_type}"
394
+ )
395
+ print("-" * 100)
396
+
397
+ if config_shortcut_name in aws_config_map:
398
+ config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
399
+ else:
400
+ config_file = config_shortcut_name
401
+
402
+ if model_shortcut_name in aws_model_maps:
403
+ model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
404
+ else:
405
+ model_file = model_shortcut_name
406
+
407
+ if os.path.isfile(model_shortcut_name):
408
+ model_shortcut_name = "converted_model"
409
+
410
+ convert_pt_checkpoint_to_tf(
411
+ model_type=model_type,
412
+ pytorch_checkpoint_path=model_file,
413
+ config_file=config_file,
414
+ tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
415
+ compare_with_pt_model=compare_with_pt_model,
416
+ )
417
+ if remove_cached_files:
418
+ os.remove(config_file)
419
+ os.remove(model_file)
420
+
421
+
422
+ if __name__ == "__main__":
423
+ parser = argparse.ArgumentParser()
424
+ # Required parameters
425
+ parser.add_argument(
426
+ "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
427
+ )
428
+ parser.add_argument(
429
+ "--model_type",
430
+ default=None,
431
+ type=str,
432
+ help=(
433
+ f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and "
434
+ "convert all the models from AWS."
435
+ ),
436
+ )
437
+ parser.add_argument(
438
+ "--pytorch_checkpoint_path",
439
+ default=None,
440
+ type=str,
441
+ help=(
442
+ "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
443
+ "If not given, will download and convert all the checkpoints from AWS."
444
+ ),
445
+ )
446
+ parser.add_argument(
447
+ "--config_file",
448
+ default=None,
449
+ type=str,
450
+ help=(
451
+ "The config json file corresponding to the pre-trained model. \n"
452
+ "This specifies the model architecture. If not given and "
453
+ "--pytorch_checkpoint_path is not given or is a shortcut name "
454
+ "use the configuration associated to the shortcut name on the AWS"
455
+ ),
456
+ )
457
+ parser.add_argument(
458
+ "--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
459
+ )
460
+ parser.add_argument(
461
+ "--use_cached_models",
462
+ action="store_true",
463
+ help="Use cached models if possible instead of updating to latest checkpoint versions.",
464
+ )
465
+ parser.add_argument(
466
+ "--remove_cached_files",
467
+ action="store_true",
468
+ help="Remove pytorch models after conversion (save memory when converting in batches).",
469
+ )
470
+ parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
471
+ args = parser.parse_args()
472
+
473
+ # if args.pytorch_checkpoint_path is not None:
474
+ # convert_pt_checkpoint_to_tf(args.model_type.lower(),
475
+ # args.pytorch_checkpoint_path,
476
+ # args.config_file if args.config_file is not None else args.pytorch_checkpoint_path,
477
+ # args.tf_dump_path,
478
+ # compare_with_pt_model=args.compare_with_pt_model,
479
+ # use_cached_models=args.use_cached_models)
480
+ # else:
481
+ convert_all_pt_checkpoints_to_tf(
482
+ args.model_type.lower() if args.model_type is not None else None,
483
+ args.tf_dump_path,
484
+ model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
485
+ if args.pytorch_checkpoint_path is not None
486
+ else None,
487
+ config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
488
+ compare_with_pt_model=args.compare_with_pt_model,
489
+ use_cached_models=args.use_cached_models,
490
+ remove_cached_files=args.remove_cached_files,
491
+ only_convert_finetuned_models=args.only_convert_finetuned_models,
492
+ )
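A minimal sketch of calling the single-checkpoint converter above directly from Python instead of going through argparse; all paths are hypothetical and `model_type` must be one of the `MODEL_CLASSES` keys:

from transformers.convert_pytorch_checkpoint_to_tf2 import convert_pt_checkpoint_to_tf

convert_pt_checkpoint_to_tf(
    model_type="bert",
    pytorch_checkpoint_path="path/to/pytorch_model.bin",
    config_file="path/to/config.json",
    tf_dump_path="path/to/tf_model.h5",
    compare_with_pt_model=True,
)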
transformers_4_35_0/convert_slow_tokenizer.py ADDED
@@ -0,0 +1,1318 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Utilities to convert slow tokenizers into their fast tokenizer counterparts.
17
+
18
+ All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and
19
+ allow making our dependency on SentencePiece optional.
20
+ """
21
+
22
+ import warnings
23
+ from typing import Dict, List, Tuple
24
+
25
+ from packaging import version
26
+ from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
27
+ from tokenizers.models import BPE, Unigram, WordPiece
28
+
29
+ from .utils import is_protobuf_available, requires_backends
30
+ from .utils.import_utils import PROTOBUF_IMPORT_ERROR
31
+
32
+
33
+ def import_protobuf(error_message=""):
34
+ if is_protobuf_available():
35
+ import google.protobuf
36
+
37
+ if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
38
+ from transformers.utils import sentencepiece_model_pb2
39
+ else:
40
+ from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
41
+ return sentencepiece_model_pb2
42
+ else:
43
+ raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
44
+
45
+
46
+ class SentencePieceExtractor:
47
+ """
48
+ Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
49
+ """
50
+
51
+ def __init__(self, model: str):
52
+ requires_backends(self, "sentencepiece")
53
+ from sentencepiece import SentencePieceProcessor
54
+
55
+ self.sp = SentencePieceProcessor()
56
+ self.sp.Load(model)
57
+
58
+ def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
59
+ """
60
+ By default, returns the vocab and merges with respect to their order; by sending `vocab_scores` we're going to
61
+ order the merges with respect to the piece scores instead.
62
+ """
63
+ sp = self.sp
64
+ vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
65
+ if vocab_scores is not None:
66
+ vocab_scores, reverse = dict(vocab_scores), True
67
+ else:
68
+ vocab_scores, reverse = vocab, False
69
+
70
+ # Merges
71
+ merges = []
72
+ for merge, piece_score in vocab_scores.items():
73
+ local = []
74
+ for index in range(1, len(merge)):
75
+ piece_l, piece_r = merge[:index], merge[index:]
76
+ if piece_l in vocab and piece_r in vocab:
77
+ local.append((piece_l, piece_r, piece_score))
78
+ local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
79
+ merges.extend(local)
80
+
81
+ merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
82
+ merges = [(val[0], val[1]) for val in merges]
83
+ return vocab, merges
84
+
85
+
86
+ def check_number_comma(piece: str) -> bool:
87
+ return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
88
+
89
+
90
+ class Converter:
91
+ def __init__(self, original_tokenizer):
92
+ self.original_tokenizer = original_tokenizer
93
+
94
+ def converted(self) -> Tokenizer:
95
+ raise NotImplementedError()
96
+
97
+
98
+ class BertConverter(Converter):
99
+ def converted(self) -> Tokenizer:
100
+ vocab = self.original_tokenizer.vocab
101
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
102
+
103
+ tokenize_chinese_chars = False
104
+ strip_accents = False
105
+ do_lower_case = False
106
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
107
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
108
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
109
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
110
+
111
+ tokenizer.normalizer = normalizers.BertNormalizer(
112
+ clean_text=True,
113
+ handle_chinese_chars=tokenize_chinese_chars,
114
+ strip_accents=strip_accents,
115
+ lowercase=do_lower_case,
116
+ )
117
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
118
+
119
+ cls = str(self.original_tokenizer.cls_token)
120
+ sep = str(self.original_tokenizer.sep_token)
121
+ cls_token_id = self.original_tokenizer.cls_token_id
122
+ sep_token_id = self.original_tokenizer.sep_token_id
123
+
124
+ tokenizer.post_processor = processors.TemplateProcessing(
125
+ single=f"{cls}:0 $A:0 {sep}:0",
126
+ pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
127
+ special_tokens=[
128
+ (cls, cls_token_id),
129
+ (sep, sep_token_id),
130
+ ],
131
+ )
132
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
133
+
134
+ return tokenizer
135
+
136
+
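A minimal sketch of using the converter above to build a fast `tokenizers.Tokenizer` backend from a slow tokenizer instance (the checkpoint name is illustrative):

from transformers import BertTokenizer
from transformers.convert_slow_tokenizer import BertConverter

slow_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
fast_backend = BertConverter(slow_tokenizer).converted()  # a tokenizers.Tokenizer instance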
137
+ class SplinterConverter(Converter):
138
+ def converted(self) -> Tokenizer:
139
+ vocab = self.original_tokenizer.vocab
140
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
141
+
142
+ tokenize_chinese_chars = False
143
+ strip_accents = False
144
+ do_lower_case = False
145
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
146
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
147
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
148
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
149
+
150
+ tokenizer.normalizer = normalizers.BertNormalizer(
151
+ clean_text=True,
152
+ handle_chinese_chars=tokenize_chinese_chars,
153
+ strip_accents=strip_accents,
154
+ lowercase=do_lower_case,
155
+ )
156
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
157
+
158
+ cls = str(self.original_tokenizer.cls_token)
159
+ sep = str(self.original_tokenizer.sep_token)
160
+ question = str(self.original_tokenizer.question_token)
161
+ dot = "."
162
+ cls_token_id = self.original_tokenizer.cls_token_id
163
+ sep_token_id = self.original_tokenizer.sep_token_id
164
+ question_token_id = self.original_tokenizer.question_token_id
165
+ dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".")
166
+
167
+ if self.original_tokenizer.padding_side == "right":
168
+ pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
169
+ else:
170
+ pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"
171
+
172
+ tokenizer.post_processor = processors.TemplateProcessing(
173
+ single=f"{cls}:0 $A:0 {sep}:0",
174
+ pair=pair,
175
+ special_tokens=[
176
+ (cls, cls_token_id),
177
+ (sep, sep_token_id),
178
+ (question, question_token_id),
179
+ (dot, dot_token_id),
180
+ ],
181
+ )
182
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
183
+
184
+ return tokenizer
185
+
186
+
187
+ class FunnelConverter(Converter):
188
+ def converted(self) -> Tokenizer:
189
+ vocab = self.original_tokenizer.vocab
190
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
191
+
192
+ tokenize_chinese_chars = False
193
+ strip_accents = False
194
+ do_lower_case = False
195
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
196
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
197
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
198
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
199
+
200
+ tokenizer.normalizer = normalizers.BertNormalizer(
201
+ clean_text=True,
202
+ handle_chinese_chars=tokenize_chinese_chars,
203
+ strip_accents=strip_accents,
204
+ lowercase=do_lower_case,
205
+ )
206
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
207
+
208
+ cls = str(self.original_tokenizer.cls_token)
209
+ sep = str(self.original_tokenizer.sep_token)
210
+ cls_token_id = self.original_tokenizer.cls_token_id
211
+ sep_token_id = self.original_tokenizer.sep_token_id
212
+
213
+ tokenizer.post_processor = processors.TemplateProcessing(
214
+ single=f"{cls}:2 $A:0 {sep}:0", # token_type_id is 2 for Funnel transformer
215
+ pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
216
+ special_tokens=[
217
+ (cls, cls_token_id),
218
+ (sep, sep_token_id),
219
+ ],
220
+ )
221
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
222
+
223
+ return tokenizer
224
+
225
+
226
+ class MPNetConverter(Converter):
227
+ def converted(self) -> Tokenizer:
228
+ vocab = self.original_tokenizer.vocab
229
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
230
+
231
+ tokenize_chinese_chars = False
232
+ strip_accents = False
233
+ do_lower_case = False
234
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
235
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
236
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
237
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
238
+
239
+ tokenizer.normalizer = normalizers.BertNormalizer(
240
+ clean_text=True,
241
+ handle_chinese_chars=tokenize_chinese_chars,
242
+ strip_accents=strip_accents,
243
+ lowercase=do_lower_case,
244
+ )
245
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
246
+
247
+ cls = str(self.original_tokenizer.cls_token)
248
+ sep = str(self.original_tokenizer.sep_token)
249
+ cls_token_id = self.original_tokenizer.cls_token_id
250
+ sep_token_id = self.original_tokenizer.sep_token_id
251
+
252
+ tokenizer.post_processor = processors.TemplateProcessing(
253
+ single=f"{cls}:0 $A:0 {sep}:0",
254
+ pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1", # MPNet uses two [SEP] tokens
255
+ special_tokens=[
256
+ (cls, cls_token_id),
257
+ (sep, sep_token_id),
258
+ ],
259
+ )
260
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
261
+
262
+ return tokenizer
263
+
264
+
265
+ class OpenAIGPTConverter(Converter):
266
+ def converted(self) -> Tokenizer:
267
+ vocab = self.original_tokenizer.encoder
268
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
269
+ unk_token = self.original_tokenizer.unk_token
270
+
271
+ tokenizer = Tokenizer(
272
+ BPE(
273
+ vocab=vocab,
274
+ merges=merges,
275
+ dropout=None,
276
+ unk_token=str(unk_token),
277
+ end_of_word_suffix="</w>",
278
+ fuse_unk=False,
279
+ )
280
+ )
281
+
282
+ if tokenizer.token_to_id(str(unk_token)) is not None:
283
+ tokenizer.add_special_tokens([str(unk_token)])
284
+
285
+ tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
286
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
287
+ tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")
288
+
289
+ return tokenizer
290
+
291
+
292
+ class GPT2Converter(Converter):
293
+ def converted(self) -> Tokenizer:
294
+ vocab = self.original_tokenizer.encoder
295
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
296
+
297
+ tokenizer = Tokenizer(
298
+ BPE(
299
+ vocab=vocab,
300
+ merges=merges,
301
+ dropout=None,
302
+ continuing_subword_prefix="",
303
+ end_of_word_suffix="",
304
+ fuse_unk=False,
305
+ )
306
+ )
307
+
308
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
309
+ tokenizer.decoder = decoders.ByteLevel()
310
+ if self.original_tokenizer.add_bos_token:
311
+ bos = self.original_tokenizer.bos_token
312
+ bos_token_id = self.original_tokenizer.bos_token_id
313
+ tokenizer.post_processor = processors.TemplateProcessing(
314
+ single=f"{bos}:0 $A:0",
315
+ pair=f"{bos}:0 $A:0 $B:1",
316
+ special_tokens=[
317
+ (bos, bos_token_id),
318
+ ],
319
+ )
320
+ else:
321
+ # XXX trim_offsets=False actually means this post_processor doesn't
322
+ # really do anything.
323
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
324
+ return tokenizer
325
+
326
+
327
+ class HerbertConverter(Converter):
328
+ def converted(self) -> Tokenizer:
329
+ tokenizer_info_str = "#version:"
330
+ token_suffix = "</w>"
331
+
332
+ vocab = self.original_tokenizer.encoder
333
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
334
+ if tokenizer_info_str in merges[0][0]:
335
+ merges = merges[1:]
336
+
337
+ tokenizer = Tokenizer(
338
+ BPE(
339
+ vocab,
340
+ merges,
341
+ dropout=None,
342
+ unk_token=self.original_tokenizer.unk_token,
343
+ end_of_word_suffix=token_suffix,
344
+ )
345
+ )
346
+
347
+ tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
348
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
349
+ tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
350
+ tokenizer.post_processor = processors.BertProcessing(
351
+ sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
352
+ cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
353
+ )
354
+
355
+ return tokenizer
356
+
357
+
358
+ class RobertaConverter(Converter):
359
+ def converted(self) -> Tokenizer:
360
+ ot = self.original_tokenizer
361
+ vocab = ot.encoder
362
+ merges = list(ot.bpe_ranks.keys())
363
+
364
+ tokenizer = Tokenizer(
365
+ BPE(
366
+ vocab=vocab,
367
+ merges=merges,
368
+ dropout=None,
369
+ continuing_subword_prefix="",
370
+ end_of_word_suffix="",
371
+ fuse_unk=False,
372
+ )
373
+ )
374
+
375
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
376
+ tokenizer.decoder = decoders.ByteLevel()
377
+ tokenizer.post_processor = processors.RobertaProcessing(
378
+ sep=(ot.sep_token, ot.sep_token_id),
379
+ cls=(ot.cls_token, ot.cls_token_id),
380
+ add_prefix_space=ot.add_prefix_space,
381
+ trim_offsets=True, # True by default on Roberta (historical)
382
+ )
383
+
384
+ return tokenizer
385
+
386
+
387
+ class RoFormerConverter(Converter):
388
+ def converted(self) -> Tokenizer:
389
+ from .models.roformer.tokenization_utils import JiebaPreTokenizer
390
+
391
+ vocab = self.original_tokenizer.vocab
392
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
393
+
394
+ strip_accents = False
395
+ do_lower_case = False
396
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
397
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
398
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
399
+
400
+ tokenizer.normalizer = normalizers.BertNormalizer(
401
+ clean_text=True,
402
+ handle_chinese_chars=False,
403
+ strip_accents=strip_accents,
404
+ lowercase=do_lower_case,
405
+ )
406
+ tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))
407
+
408
+ cls = str(self.original_tokenizer.cls_token)
409
+ sep = str(self.original_tokenizer.sep_token)
410
+ cls_token_id = self.original_tokenizer.cls_token_id
411
+ sep_token_id = self.original_tokenizer.sep_token_id
412
+
413
+ tokenizer.post_processor = processors.TemplateProcessing(
414
+ single=f"{cls}:0 $A:0 {sep}:0",
415
+ pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
416
+ special_tokens=[
417
+ (cls, cls_token_id),
418
+ (sep, sep_token_id),
419
+ ],
420
+ )
421
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
422
+
423
+ return tokenizer
424
+
425
+
426
+ class DebertaConverter(Converter):
427
+ def converted(self) -> Tokenizer:
428
+ ot = self.original_tokenizer
429
+ vocab = ot.encoder
430
+ merges = list(ot.bpe_ranks.keys())
431
+
432
+ tokenizer = Tokenizer(
433
+ BPE(
434
+ vocab=vocab,
435
+ merges=merges,
436
+ dropout=None,
437
+ continuing_subword_prefix="",
438
+ end_of_word_suffix="",
439
+ fuse_unk=False,
440
+ )
441
+ )
442
+
443
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
444
+ tokenizer.decoder = decoders.ByteLevel()
445
+ tokenizer.post_processor = processors.TemplateProcessing(
446
+ single="[CLS]:0 $A:0 [SEP]:0",
447
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
448
+ special_tokens=[
449
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
450
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
451
+ ],
452
+ )
453
+
454
+ return tokenizer
455
+
456
+
457
+ class SpmConverter(Converter):
458
+ def __init__(self, *args):
459
+ requires_backends(self, "protobuf")
460
+
461
+ super().__init__(*args)
462
+
463
+ # from .utils import sentencepiece_model_pb2 as model_pb2
464
+ model_pb2 = import_protobuf()
465
+
466
+ m = model_pb2.ModelProto()
467
+ with open(self.original_tokenizer.vocab_file, "rb") as f:
468
+ m.ParseFromString(f.read())
469
+ self.proto = m
470
+
471
+ if self.proto.trainer_spec.byte_fallback:
472
+ if not getattr(self, "handle_byte_fallback", None):
473
+ warnings.warn(
474
+ "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
475
+ " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
476
+ " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
477
+ "unknown tokens into a sequence of byte tokens matching the original piece of text."
478
+ )
479
+
480
+ def vocab(self, proto):
481
+ return [(piece.piece, piece.score) for piece in proto.pieces]
482
+
483
+ def unk_id(self, proto):
484
+ return proto.trainer_spec.unk_id
485
+
486
+ def tokenizer(self, proto):
487
+ model_type = proto.trainer_spec.model_type
488
+ vocab_scores = self.vocab(proto)
489
+ unk_id = self.unk_id(proto)
490
+
491
+ if model_type == 1:
492
+ tokenizer = Tokenizer(Unigram(vocab_scores, unk_id))
493
+ elif model_type == 2:
494
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
495
+ bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
496
+ tokenizer = Tokenizer(
497
+ BPE(
498
+ bpe_vocab,
499
+ merges,
500
+ unk_token=proto.trainer_spec.unk_piece,
501
+ fuse_unk=True,
502
+ )
503
+ )
504
+ else:
505
+ raise Exception(
506
+ "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
507
+ )
508
+
509
+ return tokenizer
510
+
511
+ def normalizer(self, proto):
512
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
513
+ if not precompiled_charsmap:
514
+ return normalizers.Sequence([normalizers.Replace(Regex(" {2,}"), " ")])
515
+ else:
516
+ return normalizers.Sequence(
517
+ [normalizers.Precompiled(precompiled_charsmap), normalizers.Replace(Regex(" {2,}"), " ")]
518
+ )
519
+
520
+ def pre_tokenizer(self, replacement, add_prefix_space):
521
+ return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
522
+
523
+ def post_processor(self):
524
+ return None
525
+
526
+ def decoder(self, replacement, add_prefix_space):
527
+ return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
528
+
529
+ def converted(self) -> Tokenizer:
530
+ tokenizer = self.tokenizer(self.proto)
531
+
532
+ # Assemble the tokenizer
533
+ normalizer = self.normalizer(self.proto)
534
+ if normalizer is not None:
535
+ tokenizer.normalizer = normalizer
536
+
537
+ replacement = "▁"
538
+ add_prefix_space = True
539
+ pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
540
+ if pre_tokenizer is not None:
541
+ tokenizer.pre_tokenizer = pre_tokenizer
542
+
543
+ tokenizer.decoder = self.decoder(replacement, add_prefix_space)
544
+ post_processor = self.post_processor()
545
+ if post_processor:
546
+ tokenizer.post_processor = post_processor
547
+
548
+ return tokenizer
549
+
550
+
551
+ class AlbertConverter(SpmConverter):
552
+ def vocab(self, proto):
553
+ return [
554
+ (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
555
+ for piece in proto.pieces
556
+ ]
557
+
558
+ def normalizer(self, proto):
559
+ list_normalizers = [
560
+ normalizers.Replace("``", '"'),
561
+ normalizers.Replace("''", '"'),
562
+ ]
563
+ if not self.original_tokenizer.keep_accents:
564
+ list_normalizers.append(normalizers.NFKD())
565
+ list_normalizers.append(normalizers.StripAccents())
566
+ if self.original_tokenizer.do_lower_case:
567
+ list_normalizers.append(normalizers.Lowercase())
568
+
569
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
570
+
571
+ if precompiled_charsmap:
572
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
573
+
574
+ list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
575
+ return normalizers.Sequence(list_normalizers)
576
+
577
+ def post_processor(self):
578
+ return processors.TemplateProcessing(
579
+ single="[CLS]:0 $A:0 [SEP]:0",
580
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
581
+ special_tokens=[
582
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
583
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
584
+ ],
585
+ )
586
+
587
+
588
+ class BarthezConverter(SpmConverter):
589
+ def unk_id(self, proto):
590
+ unk_id = 3
591
+ return unk_id
592
+
593
+ def post_processor(self):
594
+ return processors.TemplateProcessing(
595
+ single="<s> $A </s>",
596
+ pair="<s> $A </s> </s> $B </s>",
597
+ special_tokens=[
598
+ ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
599
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
600
+ ],
601
+ )
602
+
603
+
604
+ class CamembertConverter(SpmConverter):
605
+ def vocab(self, proto):
606
+ vocab = [
607
+ ("<s>NOTUSED", 0.0),
608
+ ("<pad>", 0.0),
609
+ ("</s>NOTUSED", 0.0),
610
+ ("<unk>", 0.0),
611
+ ("<unk>NOTUSED", -100),
612
+ ]
613
+ # We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead
614
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[1:]]
615
+ vocab += [("<mask>", 0.0)]
616
+ return vocab
617
+
618
+ def unk_id(self, proto):
619
+ # See vocab unk position
620
+ return 3
621
+
622
+ def post_processor(self):
623
+ return processors.TemplateProcessing(
624
+ single="<s> $A </s>",
625
+ pair="<s> $A </s> </s> $B </s>",
626
+ special_tokens=[
627
+ ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
628
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
629
+ ],
630
+ )
631
+
632
+
633
+ class DebertaV2Converter(SpmConverter):
634
+ def pre_tokenizer(self, replacement, add_prefix_space):
635
+ list_pretokenizers = []
636
+ if self.original_tokenizer.split_by_punct:
637
+ list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
638
+ list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
639
+ return pre_tokenizers.Sequence(list_pretokenizers)
640
+
641
+ def normalizer(self, proto):
642
+ list_normalizers = []
643
+ if self.original_tokenizer.do_lower_case:
644
+ list_normalizers.append(normalizers.Lowercase())
645
+ list_normalizers.append(normalizers.Strip())
646
+
647
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
648
+ if precompiled_charsmap:
649
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
650
+ list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
651
+
652
+ return normalizers.Sequence(list_normalizers)
653
+
654
+ def post_processor(self):
655
+ return processors.TemplateProcessing(
656
+ single="[CLS]:0 $A:0 [SEP]:0",
657
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
658
+ special_tokens=[
659
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
660
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
661
+ ],
662
+ )
663
+
664
+
665
+ class MBartConverter(SpmConverter):
666
+ def vocab(self, proto):
667
+ vocab = [
668
+ ("<s>", 0.0),
669
+ ("<pad>", 0.0),
670
+ ("</s>", 0.0),
671
+ ("<unk>", 0.0),
672
+ ]
673
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
674
+ vocab += [
675
+ ("ar_AR", 0.0),
676
+ ("cs_CZ", 0.0),
677
+ ("de_DE", 0.0),
678
+ ("en_XX", 0.0),
679
+ ("es_XX", 0.0),
680
+ ("et_EE", 0.0),
681
+ ("fi_FI", 0.0),
682
+ ("fr_XX", 0.0),
683
+ ("gu_IN", 0.0),
684
+ ("hi_IN", 0.0),
685
+ ("it_IT", 0.0),
686
+ ("ja_XX", 0.0),
687
+ ("kk_KZ", 0.0),
688
+ ("ko_KR", 0.0),
689
+ ("lt_LT", 0.0),
690
+ ("lv_LV", 0.0),
691
+ ("my_MM", 0.0),
692
+ ("ne_NP", 0.0),
693
+ ("nl_XX", 0.0),
694
+ ("ro_RO", 0.0),
695
+ ("ru_RU", 0.0),
696
+ ("si_LK", 0.0),
697
+ ("tr_TR", 0.0),
698
+ ("vi_VN", 0.0),
699
+ ("zh_CN", 0.0),
700
+ ]
701
+ vocab += [("<mask>", 0.0)]
702
+ return vocab
703
+
704
+ def unk_id(self, proto):
705
+ return 3
706
+
707
+ def post_processor(self):
708
+ return processors.TemplateProcessing(
709
+ single="$A </s> en_XX",
710
+ pair="$A $B </s> en_XX",
711
+ special_tokens=[
712
+ ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
713
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
714
+ ],
715
+ )
716
+
717
+
718
+ class MBart50Converter(SpmConverter):
719
+ def vocab(self, proto):
720
+ vocab = [
721
+ ("<s>", 0.0),
722
+ ("<pad>", 0.0),
723
+ ("</s>", 0.0),
724
+ ("<unk>", 0.0),
725
+ ]
726
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
727
+ # fmt: off
728
+ vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)]
729
+ # fmt: on
730
+ vocab += [("<mask>", 0.0)]
731
+ return vocab
732
+
733
+ def unk_id(self, proto):
734
+ return 3
735
+
736
+ def post_processor(self):
737
+ return processors.TemplateProcessing(
738
+ single="en_XX $A </s>",
739
+ pair="en_XX $A $B </s>",
740
+ special_tokens=[
741
+ ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
742
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
743
+ ],
744
+ )
745
+
746
+
747
+ class NllbConverter(SpmConverter):
748
+ def vocab(self, proto):
749
+ vocab = [
750
+ ("<s>", 0.0),
751
+ ("<pad>", 0.0),
752
+ ("</s>", 0.0),
753
+ ("<unk>", 0.0),
754
+ ]
755
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
756
+ vocab += [
757
+ # fmt: off
758
+ ('ace_Arab', 0.0), ('ace_Latn', 0.0), ('acm_Arab', 0.0), ('acq_Arab', 0.0), ('aeb_Arab', 0.0), ('afr_Latn', 0.0), ('ajp_Arab', 0.0), ('aka_Latn', 0.0), ('amh_Ethi', 0.0), ('apc_Arab', 0.0), ('arb_Arab', 0.0), ('ars_Arab', 0.0), ('ary_Arab', 0.0), ('arz_Arab', 0.0), ('asm_Beng', 0.0), ('ast_Latn', 0.0), ('awa_Deva', 0.0), ('ayr_Latn', 0.0), ('azb_Arab', 0.0), ('azj_Latn', 0.0), ('bak_Cyrl', 0.0), ('bam_Latn', 0.0), ('ban_Latn', 0.0), ('bel_Cyrl', 0.0), ('bem_Latn', 0.0), ('ben_Beng', 0.0), ('bho_Deva', 0.0), ('bjn_Arab', 0.0), ('bjn_Latn', 0.0), ('bod_Tibt', 0.0), ('bos_Latn', 0.0), ('bug_Latn', 0.0), ('bul_Cyrl', 0.0), ('cat_Latn', 0.0), ('ceb_Latn', 0.0), ('ces_Latn', 0.0), ('cjk_Latn', 0.0), ('ckb_Arab', 0.0), ('crh_Latn', 0.0), ('cym_Latn', 0.0), ('dan_Latn', 0.0), ('deu_Latn', 0.0), ('dik_Latn', 0.0), ('dyu_Latn', 0.0), ('dzo_Tibt', 0.0), ('ell_Grek', 0.0), ('eng_Latn', 0.0), ('epo_Latn', 0.0), ('est_Latn', 0.0), ('eus_Latn', 0.0), ('ewe_Latn', 0.0), ('fao_Latn', 0.0), ('pes_Arab', 0.0), ('fij_Latn', 0.0), ('fin_Latn', 0.0), ('fon_Latn', 0.0), ('fra_Latn', 0.0), ('fur_Latn', 0.0), ('fuv_Latn', 0.0), ('gla_Latn', 0.0), ('gle_Latn', 0.0), ('glg_Latn', 0.0), ('grn_Latn', 0.0), ('guj_Gujr', 0.0), ('hat_Latn', 0.0), ('hau_Latn', 0.0), ('heb_Hebr', 0.0), ('hin_Deva', 0.0), ('hne_Deva', 0.0), ('hrv_Latn', 0.0), ('hun_Latn', 0.0), ('hye_Armn', 0.0), ('ibo_Latn', 0.0), ('ilo_Latn', 0.0), ('ind_Latn', 0.0), ('isl_Latn', 0.0), ('ita_Latn', 0.0), ('jav_Latn', 0.0), ('jpn_Jpan', 0.0), ('kab_Latn', 0.0), ('kac_Latn', 0.0), ('kam_Latn', 0.0), ('kan_Knda', 0.0), ('kas_Arab', 0.0), ('kas_Deva', 0.0), ('kat_Geor', 0.0), ('knc_Arab', 0.0), ('knc_Latn', 0.0), ('kaz_Cyrl', 0.0), ('kbp_Latn', 0.0), ('kea_Latn', 0.0), ('khm_Khmr', 0.0), ('kik_Latn', 0.0), ('kin_Latn', 0.0), ('kir_Cyrl', 0.0), ('kmb_Latn', 0.0), ('kon_Latn', 0.0), ('kor_Hang', 0.0), ('kmr_Latn', 0.0), ('lao_Laoo', 0.0), ('lvs_Latn', 0.0), ('lij_Latn', 0.0), ('lim_Latn', 0.0), ('lin_Latn', 0.0), ('lit_Latn', 0.0), ('lmo_Latn', 0.0), ('ltg_Latn', 0.0), ('ltz_Latn', 0.0), ('lua_Latn', 0.0), ('lug_Latn', 0.0), ('luo_Latn', 0.0), ('lus_Latn', 0.0), ('mag_Deva', 0.0), ('mai_Deva', 0.0), ('mal_Mlym', 0.0), ('mar_Deva', 0.0), ('min_Latn', 0.0), ('mkd_Cyrl', 0.0), ('plt_Latn', 0.0), ('mlt_Latn', 0.0), ('mni_Beng', 0.0), ('khk_Cyrl', 0.0), ('mos_Latn', 0.0), ('mri_Latn', 0.0), ('zsm_Latn', 0.0), ('mya_Mymr', 0.0), ('nld_Latn', 0.0), ('nno_Latn', 0.0), ('nob_Latn', 0.0), ('npi_Deva', 0.0), ('nso_Latn', 0.0), ('nus_Latn', 0.0), ('nya_Latn', 0.0), ('oci_Latn', 0.0), ('gaz_Latn', 0.0), ('ory_Orya', 0.0), ('pag_Latn', 0.0), ('pan_Guru', 0.0), ('pap_Latn', 0.0), ('pol_Latn', 0.0), ('por_Latn', 0.0), ('prs_Arab', 0.0), ('pbt_Arab', 0.0), ('quy_Latn', 0.0), ('ron_Latn', 0.0), ('run_Latn', 0.0), ('rus_Cyrl', 0.0), ('sag_Latn', 0.0), ('san_Deva', 0.0), ('sat_Beng', 0.0), ('scn_Latn', 0.0), ('shn_Mymr', 0.0), ('sin_Sinh', 0.0), ('slk_Latn', 0.0), ('slv_Latn', 0.0), ('smo_Latn', 0.0), ('sna_Latn', 0.0), ('snd_Arab', 0.0), ('som_Latn', 0.0), ('sot_Latn', 0.0), ('spa_Latn', 0.0), ('als_Latn', 0.0), ('srd_Latn', 0.0), ('srp_Cyrl', 0.0), ('ssw_Latn', 0.0), ('sun_Latn', 0.0), ('swe_Latn', 0.0), ('swh_Latn', 0.0), ('szl_Latn', 0.0), ('tam_Taml', 0.0), ('tat_Cyrl', 0.0), ('tel_Telu', 0.0), ('tgk_Cyrl', 0.0), ('tgl_Latn', 0.0), ('tha_Thai', 0.0), ('tir_Ethi', 0.0), ('taq_Latn', 0.0), ('taq_Tfng', 0.0), ('tpi_Latn', 0.0), ('tsn_Latn', 0.0), ('tso_Latn', 0.0), ('tuk_Latn', 0.0), ('tum_Latn', 0.0), ('tur_Latn', 0.0), ('twi_Latn', 0.0), ('tzm_Tfng', 0.0), ('uig_Arab', 0.0), 
('ukr_Cyrl', 0.0), ('umb_Latn', 0.0), ('urd_Arab', 0.0), ('uzn_Latn', 0.0), ('vec_Latn', 0.0), ('vie_Latn', 0.0), ('war_Latn', 0.0), ('wol_Latn', 0.0), ('xho_Latn', 0.0), ('ydd_Hebr', 0.0), ('yor_Latn', 0.0), ('yue_Hant', 0.0), ('zho_Hans', 0.0), ('zho_Hant', 0.0), ('zul_Latn', 0.0)
759
+ # fmt: on
760
+ ]
761
+ vocab += [("<mask>", 0.0)]
762
+ return vocab
763
+
764
+ def unk_id(self, proto):
765
+ return 3
766
+
767
+ def post_processor(self):
768
+ return processors.TemplateProcessing(
769
+ single="eng_Latn $A </s>",
770
+ pair="eng_Latn $A $B </s>",
771
+ special_tokens=[
772
+ ("eng_Latn", self.original_tokenizer.convert_tokens_to_ids("eng_Latn")),
773
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
774
+ ],
775
+ )
776
+
777
+
778
+ class XLMRobertaConverter(SpmConverter):
779
+ def vocab(self, proto):
780
+ vocab = [
781
+ ("<s>", 0.0),
782
+ ("<pad>", 0.0),
783
+ ("</s>", 0.0),
784
+ ("<unk>", 0.0),
785
+ ]
786
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
787
+ vocab += [("<mask>", 0.0)]
788
+ return vocab
789
+
790
+ def unk_id(self, proto):
791
+ unk_id = 3
792
+ return unk_id
793
+
794
+ def post_processor(self):
795
+ return processors.TemplateProcessing(
796
+ single="<s> $A </s>",
797
+ pair="<s> $A </s> </s> $B </s>",
798
+ special_tokens=[
799
+ ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
800
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
801
+ ],
802
+ )
803
+
804
+
805
+ class XLNetConverter(SpmConverter):
806
+ def vocab(self, proto):
807
+ return [
808
+ (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
809
+ for piece in proto.pieces
810
+ ]
811
+
812
+ def normalizer(self, proto):
813
+ list_normalizers = [
814
+ normalizers.Replace("``", '"'),
815
+ normalizers.Replace("''", '"'),
816
+ ]
817
+ if not self.original_tokenizer.keep_accents:
818
+ list_normalizers.append(normalizers.NFKD())
819
+ list_normalizers.append(normalizers.StripAccents())
820
+ if self.original_tokenizer.do_lower_case:
821
+ list_normalizers.append(normalizers.Lowercase())
822
+
823
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
824
+
825
+ if precompiled_charsmap:
826
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
827
+
828
+ list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
829
+ return normalizers.Sequence(list_normalizers)
830
+
831
+ def post_processor(self):
832
+ return processors.TemplateProcessing(
833
+ single="$A:0 <sep>:0 <cls>:2",
834
+ pair="$A:0 <sep>:0 $B:1 <sep>:1 <cls>:2",
835
+ special_tokens=[
836
+ ("<sep>", self.original_tokenizer.convert_tokens_to_ids("<sep>")),
837
+ ("<cls>", self.original_tokenizer.convert_tokens_to_ids("<cls>")),
838
+ ],
839
+ )
840
+
841
+
842
+ class ReformerConverter(SpmConverter):
843
+ pass
844
+
845
+
846
+ class RemBertConverter(SpmConverter):
847
+ # Inspired from AlbertConverter
848
+ def normalizer(self, proto):
849
+ list_normalizers = [
850
+ normalizers.Replace("``", '"'),
851
+ normalizers.Replace("''", '"'),
852
+ normalizers.Replace(Regex(" {2,}"), " "),
853
+ ]
854
+ if not self.original_tokenizer.keep_accents:
855
+ list_normalizers.append(normalizers.NFKD())
856
+ list_normalizers.append(normalizers.StripAccents())
857
+ if self.original_tokenizer.do_lower_case:
858
+ list_normalizers.append(normalizers.Lowercase())
859
+
860
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
861
+
862
+ if precompiled_charsmap:
863
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
864
+
865
+ return normalizers.Sequence(list_normalizers)
866
+
867
+ def post_processor(self):
868
+ return processors.TemplateProcessing(
869
+ single="[CLS]:0 $A:0 [SEP]:0",
870
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
871
+ special_tokens=[
872
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
873
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
874
+ ],
875
+ )
876
+
877
+
878
+ class BertGenerationConverter(SpmConverter):
879
+ pass
880
+
881
+
882
+ class PegasusConverter(SpmConverter):
883
+ def vocab(self, proto):
884
+ vocab = [
885
+ (self.original_tokenizer.pad_token, 0.0),
886
+ (self.original_tokenizer.eos_token, 0.0),
887
+ ]
888
+
889
+ if self.original_tokenizer.mask_token_sent is not None:
890
+ vocab += [(self.original_tokenizer.mask_token_sent, 0.0)]
891
+
892
+ if (
893
+ self.original_tokenizer.mask_token is not None
894
+ and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset
895
+ ):
896
+ vocab += [(self.original_tokenizer.mask_token, 0.0)]
897
+
898
+ vocab += [(f"<unk_{i}>", -100.0) for i in range(2, self.original_tokenizer.offset)]
899
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
900
+ return vocab
901
+
902
+ def unk_id(self, proto):
903
+ return proto.trainer_spec.unk_id + self.original_tokenizer.offset
904
+
905
+ def pre_tokenizer(self, replacement, add_prefix_space):
906
+ return pre_tokenizers.Sequence(
907
+ [
908
+ pre_tokenizers.WhitespaceSplit(),
909
+ pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
910
+ ]
911
+ )
912
+
913
+ def post_processor(self):
914
+ eos = self.original_tokenizer.eos_token
915
+ special_tokens = [
916
+ (eos, self.original_tokenizer.eos_token_id),
917
+ ]
918
+ return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=special_tokens)
919
+
920
+
921
+ class T5Converter(SpmConverter):
922
+ def vocab(self, proto):
923
+ num_extra_ids = self.original_tokenizer._extra_ids
924
+ vocab = [(piece.piece, piece.score) for piece in proto.pieces]
925
+ vocab += [(f"<extra_id_{i}>", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
926
+ return vocab
927
+
928
+ def post_processor(self):
929
+ return processors.TemplateProcessing(
930
+ single=["$A", "</s>"],
931
+ pair=["$A", "</s>", "$B", "</s>"],
932
+ special_tokens=[
933
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
934
+ ],
935
+ )
936
+
937
+
938
+ class WhisperConverter(Converter):
939
+ def converted(self) -> Tokenizer:
940
+ vocab = self.original_tokenizer.encoder
941
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
942
+
943
+ tokenizer = Tokenizer(
944
+ BPE(
945
+ vocab=vocab,
946
+ merges=merges,
947
+ dropout=None,
948
+ continuing_subword_prefix="",
949
+ end_of_word_suffix="",
950
+ fuse_unk=False,
951
+ )
952
+ )
953
+
954
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
955
+ tokenizer.decoder = decoders.ByteLevel()
956
+
957
+ prefix_token_ids = self.original_tokenizer.prefix_tokens
958
+ prefixes = self.original_tokenizer.convert_ids_to_tokens(prefix_token_ids)
959
+ eos = self.original_tokenizer.eos_token
960
+ eos_token_id = self.original_tokenizer.eos_token_id
961
+ prefix_template = " ".join([f"{token}:0" for token in prefixes])
962
+ tokenizer.post_processor = processors.TemplateProcessing(
963
+ single=f"{prefix_template} $A:0 {eos}:0",
964
+ pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
965
+ special_tokens=[
966
+ (eos, eos_token_id),
967
+ *zip(prefixes, prefix_token_ids),
968
+ ],
969
+ )
970
+
971
+ return tokenizer
972
+
973
+
974
+ class BigBirdConverter(SpmConverter):
975
+ def post_processor(self):
976
+ return processors.TemplateProcessing(
977
+ single="[CLS]:0 $A:0 [SEP]:0",
978
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
979
+ special_tokens=[
980
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
981
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
982
+ ],
983
+ )
984
+
985
+
986
+ class CLIPConverter(Converter):
987
+ def converted(self) -> Tokenizer:
988
+ vocab = self.original_tokenizer.encoder
989
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
990
+ unk_token = self.original_tokenizer.unk_token
991
+
992
+ tokenizer = Tokenizer(
993
+ BPE(
994
+ vocab=vocab,
995
+ merges=merges,
996
+ dropout=None,
997
+ continuing_subword_prefix="",
998
+ end_of_word_suffix="</w>",
999
+ fuse_unk=False,
1000
+ unk_token=str(unk_token),
1001
+ )
1002
+ )
1003
+
1004
+ tokenizer.normalizer = normalizers.Sequence(
1005
+ [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
1006
+ )
1007
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
1008
+ [
1009
+ pre_tokenizers.Split(
1010
+ Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
1011
+ behavior="removed",
1012
+ invert=True,
1013
+ ),
1014
+ pre_tokenizers.ByteLevel(add_prefix_space=False),
1015
+ ]
1016
+ )
1017
+ tokenizer.decoder = decoders.ByteLevel()
1018
+
1019
+ # Hack to have a ByteLevel and TemplateProcessor
1020
+ tokenizer.post_processor = processors.RobertaProcessing(
1021
+ sep=(self.original_tokenizer.eos_token, self.original_tokenizer.eos_token_id),
1022
+ cls=(self.original_tokenizer.bos_token, self.original_tokenizer.bos_token_id),
1023
+ add_prefix_space=False,
1024
+ trim_offsets=False,
1025
+ )
1026
+ return tokenizer
1027
+
1028
+
1029
+ class LayoutLMv2Converter(Converter):
1030
+ def converted(self) -> Tokenizer:
1031
+ vocab = self.original_tokenizer.vocab
1032
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
1033
+
1034
+ tokenize_chinese_chars = False
1035
+ strip_accents = False
1036
+ do_lower_case = True
1037
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
1038
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
1039
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
1040
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
1041
+
1042
+ tokenizer.normalizer = normalizers.BertNormalizer(
1043
+ clean_text=True,
1044
+ handle_chinese_chars=tokenize_chinese_chars,
1045
+ strip_accents=strip_accents,
1046
+ lowercase=do_lower_case,
1047
+ )
1048
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
1049
+
1050
+ cls = str(self.original_tokenizer.cls_token)
1051
+ sep = str(self.original_tokenizer.sep_token)
1052
+ cls_token_id = self.original_tokenizer.cls_token_id
1053
+ sep_token_id = self.original_tokenizer.sep_token_id
1054
+
1055
+ tokenizer.post_processor = processors.TemplateProcessing(
1056
+ single=f"{cls}:0 $A:0 {sep}:0",
1057
+ pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
1058
+ special_tokens=[
1059
+ (cls, cls_token_id),
1060
+ (sep, sep_token_id),
1061
+ ],
1062
+ )
1063
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
1064
+
1065
+ return tokenizer
1066
+
1067
+
1068
+ class BlenderbotConverter(Converter):
1069
+ def converted(self) -> Tokenizer:
1070
+ ot = self.original_tokenizer
1071
+ vocab = ot.encoder
1072
+ merges = list(ot.bpe_ranks.keys())
1073
+
1074
+ tokenizer = Tokenizer(
1075
+ BPE(
1076
+ vocab=vocab,
1077
+ merges=merges,
1078
+ dropout=None,
1079
+ continuing_subword_prefix="",
1080
+ end_of_word_suffix="",
1081
+ fuse_unk=False,
1082
+ )
1083
+ )
1084
+
1085
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
1086
+ tokenizer.decoder = decoders.ByteLevel()
1087
+ tokenizer.post_processor = processors.TemplateProcessing(
1088
+ single=f"$A:0 {ot.eos_token}:0",
1089
+ special_tokens=[
1090
+ (ot.eos_token, ot.eos_token_id),
1091
+ ],
1092
+ )
1093
+
1094
+ return tokenizer
1095
+
1096
+
1097
+ class XGLMConverter(SpmConverter):
1098
+ def vocab(self, proto):
1099
+ vocab = [
1100
+ ("<s>", 0.0),
1101
+ ("<pad>", 0.0),
1102
+ ("</s>", 0.0),
1103
+ ("<unk>", 0.0),
1104
+ ]
1105
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
1106
+ # fmt: off
1107
+ vocab += [("<madeupword0>", 0.0), ("<madeupword1>", 0.0), ("<madeupword2>", 0.0), ("<madeupword3>", 0.0), ("<madeupword4>", 0.0), ("<madeupword5>", 0.0), ("<madeupword6>", 0.0)]
1108
+ # fmt: on
1109
+ return vocab
1110
+
1111
+ def unk_id(self, proto):
1112
+ unk_id = 3
1113
+ return unk_id
1114
+
1115
+ def post_processor(self):
1116
+ return processors.TemplateProcessing(
1117
+ single="</s> $A",
1118
+ pair="</s> $A </s> </s> $B",
1119
+ special_tokens=[
1120
+ ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
1121
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
1122
+ ],
1123
+ )
1124
+
1125
+
1126
+ class LlamaConverter(SpmConverter):
1127
+ handle_byte_fallback = True
1128
+
1129
+ def vocab(self, proto):
1130
+ vocab = [
1131
+ ("<unk>", 0.0),
1132
+ ("<s>", 0.0),
1133
+ ("</s>", 0.0),
1134
+ ]
1135
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
1136
+ return vocab
1137
+
1138
+ def unk_id(self, proto):
1139
+ unk_id = 0
1140
+ return unk_id
1141
+
1142
+ def decoder(self, replacement, add_prefix_space):
1143
+ return decoders.Sequence(
1144
+ [
1145
+ decoders.Replace("▁", " "),
1146
+ decoders.ByteFallback(),
1147
+ decoders.Fuse(),
1148
+ decoders.Strip(content=" ", left=1),
1149
+ ]
1150
+ )
1151
+
1152
+ def tokenizer(self, proto):
1153
+ model_type = proto.trainer_spec.model_type
1154
+ vocab_scores = self.vocab(proto)
1155
+ if model_type == 1:
1156
+ import tokenizers
1157
+
1158
+ if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
1159
+ tokenizer = Tokenizer(Unigram(vocab_scores, 0))
1160
+ else:
1161
+ tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))
1162
+
1163
+ elif model_type == 2:
1164
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
1165
+ bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
1166
+ tokenizer = Tokenizer(
1167
+ BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
1168
+ )
1169
+ tokenizer.add_special_tokens(
1170
+ [
1171
+ AddedToken("<unk>"),
1172
+ AddedToken("<s>"),
1173
+ AddedToken("</s>"),
1174
+ ]
1175
+ )
1176
+ else:
1177
+ raise Exception(
1178
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
1179
+ )
1180
+
1181
+ return tokenizer
1182
+
1183
+ def normalizer(self, proto):
1184
+ return normalizers.Sequence(
1185
+ [
1186
+ normalizers.Prepend(prepend="▁"),
1187
+ normalizers.Replace(pattern=" ", content="▁"),
1188
+ ]
1189
+ )
1190
+
1191
+ def pre_tokenizer(self, replacement, add_prefix_space):
1192
+ return None
1193
+
1194
+ def post_processor(self):
1195
+ # the processor is defined in the LlamaTokenizerFast class.
1196
+ return None
1197
+
1198
+
1199
+ class MarkupLMConverter(Converter):
1200
+ def converted(self) -> Tokenizer:
1201
+ ot = self.original_tokenizer
1202
+ vocab = ot.encoder
1203
+ merges = list(ot.bpe_ranks.keys())
1204
+
1205
+ tokenizer = Tokenizer(
1206
+ BPE(
1207
+ vocab=vocab,
1208
+ merges=merges,
1209
+ dropout=None,
1210
+ continuing_subword_prefix="",
1211
+ end_of_word_suffix="",
1212
+ fuse_unk=False,
1213
+ unk_token=self.original_tokenizer.unk_token,
1214
+ )
1215
+ )
1216
+
1217
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
1218
+ tokenizer.decoder = decoders.ByteLevel()
1219
+
1220
+ cls = str(self.original_tokenizer.cls_token)
1221
+ sep = str(self.original_tokenizer.sep_token)
1222
+ cls_token_id = self.original_tokenizer.cls_token_id
1223
+ sep_token_id = self.original_tokenizer.sep_token_id
1224
+
1225
+ tokenizer.post_processor = processors.TemplateProcessing(
1226
+ single=f"{cls} $A {sep}",
1227
+ pair=f"{cls} $A {sep} $B {sep}",
1228
+ special_tokens=[
1229
+ (cls, cls_token_id),
1230
+ (sep, sep_token_id),
1231
+ ],
1232
+ )
1233
+
1234
+ return tokenizer
1235
+
1236
+
1237
+ SLOW_TO_FAST_CONVERTERS = {
1238
+ "AlbertTokenizer": AlbertConverter,
1239
+ "BartTokenizer": RobertaConverter,
1240
+ "BarthezTokenizer": BarthezConverter,
1241
+ "BertTokenizer": BertConverter,
1242
+ "BigBirdTokenizer": BigBirdConverter,
1243
+ "BlenderbotTokenizer": BlenderbotConverter,
1244
+ "CamembertTokenizer": CamembertConverter,
1245
+ "CLIPTokenizer": CLIPConverter,
1246
+ "CodeGenTokenizer": GPT2Converter,
1247
+ "ConvBertTokenizer": BertConverter,
1248
+ "DebertaTokenizer": DebertaConverter,
1249
+ "DebertaV2Tokenizer": DebertaV2Converter,
1250
+ "DistilBertTokenizer": BertConverter,
1251
+ "DPRReaderTokenizer": BertConverter,
1252
+ "DPRQuestionEncoderTokenizer": BertConverter,
1253
+ "DPRContextEncoderTokenizer": BertConverter,
1254
+ "ElectraTokenizer": BertConverter,
1255
+ "FNetTokenizer": AlbertConverter,
1256
+ "FunnelTokenizer": FunnelConverter,
1257
+ "GPT2Tokenizer": GPT2Converter,
1258
+ "HerbertTokenizer": HerbertConverter,
1259
+ "LayoutLMTokenizer": BertConverter,
1260
+ "LayoutLMv2Tokenizer": BertConverter,
1261
+ "LayoutLMv3Tokenizer": RobertaConverter,
1262
+ "LayoutXLMTokenizer": XLMRobertaConverter,
1263
+ "LongformerTokenizer": RobertaConverter,
1264
+ "LEDTokenizer": RobertaConverter,
1265
+ "LxmertTokenizer": BertConverter,
1266
+ "MarkupLMTokenizer": MarkupLMConverter,
1267
+ "MBartTokenizer": MBartConverter,
1268
+ "MBart50Tokenizer": MBart50Converter,
1269
+ "MPNetTokenizer": MPNetConverter,
1270
+ "MobileBertTokenizer": BertConverter,
1271
+ "MvpTokenizer": RobertaConverter,
1272
+ "NllbTokenizer": NllbConverter,
1273
+ "OpenAIGPTTokenizer": OpenAIGPTConverter,
1274
+ "PegasusTokenizer": PegasusConverter,
1275
+ "RealmTokenizer": BertConverter,
1276
+ "ReformerTokenizer": ReformerConverter,
1277
+ "RemBertTokenizer": RemBertConverter,
1278
+ "RetriBertTokenizer": BertConverter,
1279
+ "RobertaTokenizer": RobertaConverter,
1280
+ "RoFormerTokenizer": RoFormerConverter,
1281
+ "SqueezeBertTokenizer": BertConverter,
1282
+ "T5Tokenizer": T5Converter,
1283
+ "WhisperTokenizer": WhisperConverter,
1284
+ "XLMRobertaTokenizer": XLMRobertaConverter,
1285
+ "XLNetTokenizer": XLNetConverter,
1286
+ "SplinterTokenizer": SplinterConverter,
1287
+ "XGLMTokenizer": XGLMConverter,
1288
+ "LlamaTokenizer": LlamaConverter,
1289
+ "CodeLlamaTokenizer": LlamaConverter,
1290
+ }
1291
+
1292
+
1293
+ def convert_slow_tokenizer(transformer_tokenizer) -> Tokenizer:
1294
+ """
1295
+ Utilities to convert a slow tokenizer instance in a fast tokenizer instance.
1296
+
1297
+ Args:
1298
+ transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
1299
+ Instance of a slow tokenizer to convert in the backend tokenizer for
1300
+ [`~tokenization_utils_base.PreTrainedTokenizerFast`].
1301
+
1302
+ Return:
1303
+ A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
1304
+ [`~tokenization_utils_base.PreTrainedTokenizerFast`]
1305
+ """
1306
+
1307
+ tokenizer_class_name = transformer_tokenizer.__class__.__name__
1308
+
1309
+ if tokenizer_class_name not in SLOW_TO_FAST_CONVERTERS:
1310
+ raise ValueError(
1311
+ f"An instance of tokenizer class {tokenizer_class_name} cannot be converted in a Fast tokenizer instance."
1312
+ " No converter was found. Currently available slow->fast convertors:"
1313
+ f" {list(SLOW_TO_FAST_CONVERTERS.keys())}"
1314
+ )
1315
+
1316
+ converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
1317
+
1318
+ return converter_class(transformer_tokenizer).converted()
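
A minimal usage sketch of `convert_slow_tokenizer` (the `transformers.convert_slow_tokenizer` import path below assumes the standard package layout; in this repo the module is vendored under `transformers_4_35_0`):

from transformers import BertTokenizer
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

# Load a slow (pure-Python) tokenizer and convert it into a `tokenizers.Tokenizer` backend.
slow_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
backend = convert_slow_tokenizer(slow_tokenizer)  # dispatched to BertConverter via SLOW_TO_FAST_CONVERTERS
print(backend.encode("Hello world").tokens)  # e.g. ['[CLS]', 'hello', 'world', '[SEP]']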
transformers_4_35_0/convert_slow_tokenizers_checkpoints_to_fast.py ADDED
@@ -0,0 +1,126 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert slow tokenizer checkpoints to the fast format (serialization format of the `tokenizers` library)."""
16
+
17
+ import argparse
18
+ import os
19
+
20
+ import transformers
21
+
22
+ from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
23
+ from .utils import logging
24
+
25
+
26
+ logging.set_verbosity_info()
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ TOKENIZER_CLASSES = {name: getattr(transformers, name + "Fast") for name in SLOW_TO_FAST_CONVERTERS}
32
+
33
+
34
+ def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path, force_download):
35
+ if tokenizer_name is not None and tokenizer_name not in TOKENIZER_CLASSES:
36
+ raise ValueError(f"Unrecognized tokenizer name, should be one of {list(TOKENIZER_CLASSES.keys())}.")
37
+
38
+ if tokenizer_name is None:
39
+ tokenizer_names = TOKENIZER_CLASSES
40
+ else:
41
+ tokenizer_names = {tokenizer_name: getattr(transformers, tokenizer_name + "Fast")}
42
+
43
+ logger.info(f"Loading tokenizer classes: {tokenizer_names}")
44
+
45
+ for tokenizer_name in tokenizer_names:
46
+ tokenizer_class = TOKENIZER_CLASSES[tokenizer_name]
47
+
48
+ add_prefix = True
49
+ if checkpoint_name is None:
50
+ checkpoint_names = list(tokenizer_class.max_model_input_sizes.keys())
51
+ else:
52
+ checkpoint_names = [checkpoint_name]
53
+
54
+ logger.info(f"For tokenizer {tokenizer_class.__name__} loading checkpoints: {checkpoint_names}")
55
+
56
+ for checkpoint in checkpoint_names:
57
+ logger.info(f"Loading {tokenizer_class.__name__} {checkpoint}")
58
+
59
+ # Load tokenizer
60
+ tokenizer = tokenizer_class.from_pretrained(checkpoint, force_download=force_download)
61
+
62
+ # Save fast tokenizer
63
+ logger.info(f"Save fast tokenizer to {dump_path} with prefix {checkpoint} add_prefix {add_prefix}")
64
+
65
+ # For organization names we create sub-directories
66
+ if "/" in checkpoint:
67
+ checkpoint_directory, checkpoint_prefix_name = checkpoint.split("/")
68
+ dump_path_full = os.path.join(dump_path, checkpoint_directory)
69
+ elif add_prefix:
70
+ checkpoint_prefix_name = checkpoint
71
+ dump_path_full = dump_path
72
+ else:
73
+ checkpoint_prefix_name = None
74
+ dump_path_full = dump_path
75
+
76
+ logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")
77
+
78
+ if checkpoint in list(tokenizer.pretrained_vocab_files_map.values())[0]:
79
+ file_path = list(tokenizer.pretrained_vocab_files_map.values())[0][checkpoint]
80
+ next_char = file_path.split(checkpoint)[-1][0]
81
+ if next_char == "/":
82
+ dump_path_full = os.path.join(dump_path_full, checkpoint_prefix_name)
83
+ checkpoint_prefix_name = None
84
+
85
+ logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")
86
+
87
+ file_names = tokenizer.save_pretrained(
88
+ dump_path_full, legacy_format=False, filename_prefix=checkpoint_prefix_name
89
+ )
90
+ logger.info(f"=> File names {file_names}")
91
+
92
+ for file_name in file_names:
93
+ if not file_name.endswith("tokenizer.json"):
94
+ os.remove(file_name)
95
+ logger.info(f"=> removing {file_name}")
96
+
97
+
98
+ if __name__ == "__main__":
99
+ parser = argparse.ArgumentParser()
100
+ # Required parameters
101
+ parser.add_argument(
102
+ "--dump_path", default=None, type=str, required=True, help="Path to output generated fast tokenizer files."
103
+ )
104
+ parser.add_argument(
105
+ "--tokenizer_name",
106
+ default=None,
107
+ type=str,
108
+ help=(
109
+ f"Optional tokenizer type selected from the list of {list(TOKENIZER_CLASSES.keys())}. If not given, will "
110
+ "download and convert all the checkpoints from AWS."
111
+ ),
112
+ )
113
+ parser.add_argument(
114
+ "--checkpoint_name",
115
+ default=None,
116
+ type=str,
117
+ help="Optional checkpoint name. If not given, will download and convert the canonical checkpoints from AWS.",
118
+ )
119
+ parser.add_argument(
120
+ "--force_download",
121
+ action="store_true",
122
+ help="Re-download checkpoints.",
123
+ )
124
+ args = parser.parse_args()
125
+
126
+ convert_slow_checkpoint_to_fast(args.tokenizer_name, args.checkpoint_name, args.dump_path, args.force_download)
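
The same conversion can also be driven programmatically, as in this sketch (`./fast_tokenizers` is a hypothetical output directory):

from transformers.convert_slow_tokenizers_checkpoints_to_fast import convert_slow_checkpoint_to_fast

# Convert a single slow BERT checkpoint; only the generated tokenizer.json files are kept.
convert_slow_checkpoint_to_fast(
    tokenizer_name="BertTokenizer",
    checkpoint_name="bert-base-uncased",
    dump_path="./fast_tokenizers",
    force_download=False,
)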
transformers_4_35_0/convert_tf_hub_seq_to_seq_bert_to_pytorch.py ADDED
@@ -0,0 +1,88 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert Seq2Seq TF Hub checkpoint."""
16
+
17
+
18
+ import argparse
19
+
20
+ from . import (
21
+ BertConfig,
22
+ BertGenerationConfig,
23
+ BertGenerationDecoder,
24
+ BertGenerationEncoder,
25
+ load_tf_weights_in_bert_generation,
26
+ logging,
27
+ )
28
+
29
+
30
+ logging.set_verbosity_info()
31
+
32
+
33
+ def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder):
34
+ # Initialise PyTorch model
35
+ bert_config = BertConfig.from_pretrained(
36
+ "bert-large-cased",
37
+ vocab_size=vocab_size,
38
+ max_position_embeddings=512,
39
+ is_decoder=True,
40
+ add_cross_attention=True,
41
+ )
42
+ bert_config_dict = bert_config.to_dict()
43
+ del bert_config_dict["type_vocab_size"]
44
+ config = BertGenerationConfig(**bert_config_dict)
45
+ if is_encoder:
46
+ model = BertGenerationEncoder(config)
47
+ else:
48
+ model = BertGenerationDecoder(config)
49
+ print(f"Building PyTorch model from configuration: {config}")
50
+
51
+ # Load weights from tf checkpoint
52
+ load_tf_weights_in_bert_generation(
53
+ model,
54
+ tf_hub_path,
55
+ model_class="bert",
56
+ is_encoder_named_decoder=is_encoder_named_decoder,
57
+ is_encoder=is_encoder,
58
+ )
59
+
60
+ # Save pytorch-model
61
+ print(f"Save PyTorch model and config to {pytorch_dump_path}")
62
+ model.save_pretrained(pytorch_dump_path)
63
+
64
+
65
+ if __name__ == "__main__":
66
+ parser = argparse.ArgumentParser()
67
+ # Required parameters
68
+ parser.add_argument(
69
+ "--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
70
+ )
71
+ parser.add_argument(
72
+ "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
73
+ )
74
+ parser.add_argument(
75
+ "--is_encoder_named_decoder",
76
+ action="store_true",
77
+ help="If decoder has to be renamed to encoder in PyTorch model.",
78
+ )
79
+ parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.")
80
+ parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model")
81
+ args = parser.parse_args()
82
+ convert_tf_checkpoint_to_pytorch(
83
+ args.tf_hub_path,
84
+ args.pytorch_dump_path,
85
+ args.is_encoder_named_decoder,
86
+ args.vocab_size,
87
+ is_encoder=args.is_encoder,
88
+ )
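
A corresponding programmatic sketch (both paths below are placeholders, and the TF Hub export must be a Seq2Seq BERT checkpoint):

from transformers.convert_tf_hub_seq_to_seq_bert_to_pytorch import convert_tf_checkpoint_to_pytorch

# Build a BertGenerationEncoder from the TF Hub export and save it as a PyTorch model.
convert_tf_checkpoint_to_pytorch(
    tf_hub_path="/path/to/tf_hub_checkpoint",   # placeholder
    pytorch_dump_path="/path/to/pytorch_dump",  # placeholder
    is_encoder_named_decoder=False,
    vocab_size=50358,
    is_encoder=True,
)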
transformers_4_35_0/data/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .data_collator import (
16
+ DataCollatorForLanguageModeling,
17
+ DataCollatorForPermutationLanguageModeling,
18
+ DataCollatorForSeq2Seq,
19
+ DataCollatorForSOP,
20
+ DataCollatorForTokenClassification,
21
+ DataCollatorForWholeWordMask,
22
+ DataCollatorWithPadding,
23
+ DefaultDataCollator,
24
+ default_data_collator,
25
+ )
26
+ from .metrics import glue_compute_metrics, xnli_compute_metrics
27
+ from .processors import (
28
+ DataProcessor,
29
+ InputExample,
30
+ InputFeatures,
31
+ SingleSentenceClassificationProcessor,
32
+ SquadExample,
33
+ SquadFeatures,
34
+ SquadV1Processor,
35
+ SquadV2Processor,
36
+ glue_convert_examples_to_features,
37
+ glue_output_modes,
38
+ glue_processors,
39
+ glue_tasks_num_labels,
40
+ squad_convert_examples_to_features,
41
+ xnli_output_modes,
42
+ xnli_processors,
43
+ xnli_tasks_num_labels,
44
+ )
transformers_4_35_0/data/data_collator.py ADDED
@@ -0,0 +1,1535 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ import warnings
17
+ from collections.abc import Mapping
18
+ from dataclasses import dataclass
19
+ from random import randint
20
+ from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+
24
+ from ..models.bert import BertTokenizer, BertTokenizerFast
25
+ from ..tokenization_utils_base import PreTrainedTokenizerBase
26
+ from ..utils import PaddingStrategy
27
+
28
+
29
+ InputDataClass = NewType("InputDataClass", Any)
30
+
31
+ """
32
+ A DataCollator is a function that takes a list of samples from a Dataset and collates them into a batch, as a dictionary
33
+ of PyTorch/TensorFlow tensors or NumPy arrays.
34
+ """
35
+ DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, Any]])
36
+
37
+
38
+ class DataCollatorMixin:
39
+ def __call__(self, features, return_tensors=None):
40
+ if return_tensors is None:
41
+ return_tensors = self.return_tensors
42
+ if return_tensors == "tf":
43
+ return self.tf_call(features)
44
+ elif return_tensors == "pt":
45
+ return self.torch_call(features)
46
+ elif return_tensors == "np":
47
+ return self.numpy_call(features)
48
+ else:
49
+ raise ValueError(f"Framework '{return_tensors}' not recognized!")
50
+
51
+
52
+ def default_data_collator(features: List[InputDataClass], return_tensors="pt") -> Dict[str, Any]:
53
+ """
54
+ Very simple data collator that simply collates batches of dict-like objects and performs special handling for
55
+ potential keys named:
56
+
57
+ - `label`: handles a single value (int or float) per object
58
+ - `label_ids`: handles a list of values per object
59
+
60
+ Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
61
+ to the model. See glue and ner for examples of how it's useful.
62
+ """
63
+
64
+ # In this function we'll make the assumption that all `features` in the batch
65
+ # have the same attributes.
66
+ # So we will look at the first element as a proxy for what attributes exist
67
+ # on the whole batch.
68
+
69
+ if return_tensors == "pt":
70
+ return torch_default_data_collator(features)
71
+ elif return_tensors == "tf":
72
+ return tf_default_data_collator(features)
73
+ elif return_tensors == "np":
74
+ return numpy_default_data_collator(features)
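
As a small illustration of the label handling above (a sketch with toy features; requires PyTorch for the "pt" return type):

from transformers.data.data_collator import default_data_collator

features = [
    {"input_ids": [101, 2023, 102], "label": 1},
    {"input_ids": [101, 7592, 102], "label": 0},
]
batch = default_data_collator(features, return_tensors="pt")
# batch["labels"]    -> tensor([1, 0]) with dtype torch.long
# batch["input_ids"] -> tensor of shape (2, 3)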
75
+
76
+
77
+ @dataclass
78
+ class DefaultDataCollator(DataCollatorMixin):
79
+ """
80
+ Very simple data collator that simply collates batches of dict-like objects and performs special handling for
81
+ potential keys named:
82
+
83
+ - `label`: handles a single value (int or float) per object
84
+ - `label_ids`: handles a list of values per object
85
+
86
+ Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
87
+ to the model. See glue and ner for examples of how it's useful.
88
+
89
+ This is an object (like other data collators) rather than a pure function like default_data_collator. This can be
90
+ helpful if you need to set a return_tensors value at initialization.
91
+
92
+ Args:
93
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
94
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
95
+ """
96
+
97
+ return_tensors: str = "pt"
98
+
99
+ def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, Any]:
100
+ if return_tensors is None:
101
+ return_tensors = self.return_tensors
102
+ return default_data_collator(features, return_tensors)
103
+
104
+
105
+ def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
106
+ import torch
107
+
108
+ if not isinstance(features[0], Mapping):
109
+ features = [vars(f) for f in features]
110
+ first = features[0]
111
+ batch = {}
112
+
113
+ # Special handling for labels.
114
+ # Ensure that tensor is created with the correct type
115
+ # (it should be automatically the case, but let's make sure of it.)
116
+ if "label" in first and first["label"] is not None:
117
+ label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
118
+ dtype = torch.long if isinstance(label, int) else torch.float
119
+ batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
120
+ elif "label_ids" in first and first["label_ids"] is not None:
121
+ if isinstance(first["label_ids"], torch.Tensor):
122
+ batch["labels"] = torch.stack([f["label_ids"] for f in features])
123
+ else:
124
+ dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
125
+ batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
126
+
127
+ # Handling of all other possible keys.
128
+ # Again, we will use the first element to figure out which key/values are not None for this model.
129
+ for k, v in first.items():
130
+ if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
131
+ if isinstance(v, torch.Tensor):
132
+ batch[k] = torch.stack([f[k] for f in features])
133
+ elif isinstance(v, np.ndarray):
134
+ batch[k] = torch.tensor(np.stack([f[k] for f in features]))
135
+ else:
136
+ batch[k] = torch.tensor([f[k] for f in features])
137
+
138
+ return batch
139
+
140
+
141
+ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
142
+ import tensorflow as tf
143
+
144
+ if not isinstance(features[0], Mapping):
145
+ features = [vars(f) for f in features]
146
+ first = features[0]
147
+ batch = {}
148
+
149
+ # Special handling for labels.
150
+ # Ensure that tensor is created with the correct type
151
+ # (it should be automatically the case, but let's make sure of it.)
152
+ if "label" in first and first["label"] is not None:
153
+ label_col_name = "label"
154
+ elif "label_ids" in first and first["label_ids"] is not None:
155
+ label_col_name = "label_ids"
156
+ elif "labels" in first and first["labels"] is not None:
157
+ label_col_name = "labels"
158
+ else:
159
+ label_col_name = None
160
+ if label_col_name is not None:
161
+ if isinstance(first[label_col_name], tf.Tensor):
162
+ dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32
163
+ elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
164
+ dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
165
+ elif isinstance(first[label_col_name], (tuple, list)):
166
+ dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32
167
+ else:
168
+ dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32
169
+ batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype)
170
+ # Handling of all other possible keys.
171
+ # Again, we will use the first element to figure out which key/values are not None for this model.
172
+ for k, v in first.items():
173
+ if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str):
174
+ if isinstance(v, (tf.Tensor, np.ndarray)):
175
+ batch[k] = tf.stack([f[k] for f in features])
176
+ else:
177
+ batch[k] = tf.convert_to_tensor([f[k] for f in features])
178
+
179
+ return batch
180
+
181
+
182
+ def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
183
+ if not isinstance(features[0], Mapping):
184
+ features = [vars(f) for f in features]
185
+ first = features[0]
186
+ batch = {}
187
+
188
+ # Special handling for labels.
189
+ # Ensure that tensor is created with the correct type
190
+ # (it should be automatically the case, but let's make sure of it.)
191
+ if "label" in first and first["label"] is not None:
192
+ label = first["label"].item() if isinstance(first["label"], np.ndarray) else first["label"]
193
+ dtype = np.int64 if isinstance(label, int) else np.float32
194
+ batch["labels"] = np.array([f["label"] for f in features], dtype=dtype)
195
+ elif "label_ids" in first and first["label_ids"] is not None:
196
+ if isinstance(first["label_ids"], np.ndarray):
197
+ batch["labels"] = np.stack([f["label_ids"] for f in features])
198
+ else:
199
+ dtype = np.int64 if type(first["label_ids"][0]) is int else np.float32
200
+ batch["labels"] = np.array([f["label_ids"] for f in features], dtype=dtype)
201
+
202
+ # Handling of all other possible keys.
203
+ # Again, we will use the first element to figure out which key/values are not None for this model.
204
+ for k, v in first.items():
205
+ if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
206
+ if isinstance(v, np.ndarray):
207
+ batch[k] = np.stack([f[k] for f in features])
208
+ else:
209
+ batch[k] = np.array([f[k] for f in features])
210
+
211
+ return batch
212
+
213
+
214
+ @dataclass
215
+ class DataCollatorWithPadding:
216
+ """
217
+ Data collator that will dynamically pad the inputs received.
218
+
219
+ Args:
220
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
221
+ The tokenizer used for encoding the data.
222
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
223
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
224
+ among:
225
+
226
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
227
+ sequence is provided).
228
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
229
+ acceptable input length for the model if that argument is not provided.
230
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
231
+ max_length (`int`, *optional*):
232
+ Maximum length of the returned list and optionally padding length (see above).
233
+ pad_to_multiple_of (`int`, *optional*):
234
+ If set will pad the sequence to a multiple of the provided value.
235
+
236
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
237
+ 7.5 (Volta).
238
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
239
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
240
+ """
241
+
242
+ tokenizer: PreTrainedTokenizerBase
243
+ padding: Union[bool, str, PaddingStrategy] = True
244
+ max_length: Optional[int] = None
245
+ pad_to_multiple_of: Optional[int] = None
246
+ return_tensors: str = "pt"
247
+
248
+ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
249
+ batch = self.tokenizer.pad(
250
+ features,
251
+ padding=self.padding,
252
+ max_length=self.max_length,
253
+ pad_to_multiple_of=self.pad_to_multiple_of,
254
+ return_tensors=self.return_tensors,
255
+ )
256
+ if "label" in batch:
257
+ batch["labels"] = batch["label"]
258
+ del batch["label"]
259
+ if "label_ids" in batch:
260
+ batch["labels"] = batch["label_ids"]
261
+ del batch["label_ids"]
262
+ return batch
263
+
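+ # A minimal usage sketch of `DataCollatorWithPadding`, assuming a tokenizer loaded
+ # elsewhere (e.g. via `AutoTokenizer.from_pretrained(...)`); the inputs are illustrative:
+ #
+ #     collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+ #     features = [tokenizer("short text"), tokenizer("a somewhat longer example sentence")]
+ #     batch = collator(features)
+ #     # input_ids / attention_mask are padded to the longest sequence, rounded up to a multiple of 8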
264
+
265
+ @dataclass
266
+ class DataCollatorForTokenClassification(DataCollatorMixin):
267
+ """
268
+ Data collator that will dynamically pad the inputs received, as well as the labels.
269
+
270
+ Args:
271
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
272
+ The tokenizer used for encoding the data.
273
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
274
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
275
+ among:
276
+
277
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
278
+ sequence is provided).
279
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
280
+ acceptable input length for the model if that argument is not provided.
281
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
282
+ max_length (`int`, *optional*):
283
+ Maximum length of the returned list and optionally padding length (see above).
284
+ pad_to_multiple_of (`int`, *optional*):
285
+ If set will pad the sequence to a multiple of the provided value.
286
+
287
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
288
+ 7.5 (Volta).
289
+ label_pad_token_id (`int`, *optional*, defaults to -100):
290
+ The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
291
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
292
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
293
+ """
294
+
295
+ tokenizer: PreTrainedTokenizerBase
296
+ padding: Union[bool, str, PaddingStrategy] = True
297
+ max_length: Optional[int] = None
298
+ pad_to_multiple_of: Optional[int] = None
299
+ label_pad_token_id: int = -100
300
+ return_tensors: str = "pt"
301
+
302
+ def torch_call(self, features):
303
+ import torch
304
+
305
+ label_name = "label" if "label" in features[0].keys() else "labels"
306
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
307
+
308
+ no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
309
+
310
+ batch = self.tokenizer.pad(
311
+ no_labels_features,
312
+ padding=self.padding,
313
+ max_length=self.max_length,
314
+ pad_to_multiple_of=self.pad_to_multiple_of,
315
+ return_tensors="pt",
316
+ )
317
+
318
+ if labels is None:
319
+ return batch
320
+
321
+ sequence_length = batch["input_ids"].shape[1]
322
+ padding_side = self.tokenizer.padding_side
323
+
324
+ def to_list(tensor_or_iterable):
325
+ if isinstance(tensor_or_iterable, torch.Tensor):
326
+ return tensor_or_iterable.tolist()
327
+ return list(tensor_or_iterable)
328
+
329
+ if padding_side == "right":
330
+ batch[label_name] = [
331
+ to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
332
+ ]
333
+ else:
334
+ batch[label_name] = [
335
+ [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
336
+ ]
337
+
338
+ batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
339
+ return batch
340
+
341
+ def tf_call(self, features):
342
+ import tensorflow as tf
343
+
344
+ label_name = "label" if "label" in features[0].keys() else "labels"
345
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
346
+ batch = self.tokenizer.pad(
347
+ features,
348
+ padding=self.padding,
349
+ max_length=self.max_length,
350
+ pad_to_multiple_of=self.pad_to_multiple_of,
351
+ # Conversion to tensors will fail if we have labels as they are not of the same length yet.
352
+ return_tensors="tf" if labels is None else None,
353
+ )
354
+
355
+ if labels is None:
356
+ return batch
357
+
358
+ sequence_length = tf.convert_to_tensor(batch["input_ids"]).shape[1]
359
+ padding_side = self.tokenizer.padding_side
360
+ if padding_side == "right":
361
+ batch["labels"] = [
362
+ list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
363
+ ]
364
+ else:
365
+ batch["labels"] = [
366
+ [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
367
+ ]
368
+
369
+ batch = {k: tf.convert_to_tensor(v, dtype=tf.int64) for k, v in batch.items()}
370
+ return batch
371
+
372
+ def numpy_call(self, features):
373
+ label_name = "label" if "label" in features[0].keys() else "labels"
374
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
375
+ batch = self.tokenizer.pad(
376
+ features,
377
+ padding=self.padding,
378
+ max_length=self.max_length,
379
+ pad_to_multiple_of=self.pad_to_multiple_of,
380
+ # Conversion to tensors will fail if we have labels as they are not of the same length yet.
381
+ return_tensors="np" if labels is None else None,
382
+ )
383
+
384
+ if labels is None:
385
+ return batch
386
+
387
+ sequence_length = np.array(batch["input_ids"]).shape[1]
388
+ padding_side = self.tokenizer.padding_side
389
+ if padding_side == "right":
390
+ batch["labels"] = [
391
+ list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
392
+ ]
393
+ else:
394
+ batch["labels"] = [
395
+ [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
396
+ ]
397
+
398
+ batch = {k: np.array(v, dtype=np.int64) for k, v in batch.items()}
399
+ return batch
400
+
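+ # Sketch of the label padding performed above, with made-up numbers: for a batch padded to
+ # sequence_length 6 and right-side padding, a label list [3, 7, 7] becomes
+ # [3, 7, 7, -100, -100, -100], so the padded positions are ignored by the loss.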
401
+
402
+ def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
403
+ """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
404
+ import torch
405
+
406
+ # Tensorize if necessary.
407
+ if isinstance(examples[0], (list, tuple, np.ndarray)):
408
+ examples = [torch.tensor(e, dtype=torch.long) for e in examples]
409
+
410
+ length_of_first = examples[0].size(0)
411
+
412
+ # Check if padding is necessary.
413
+
414
+ are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
415
+ if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
416
+ return torch.stack(examples, dim=0)
417
+
418
+ # If yes, check if we have a `pad_token`.
419
+ if tokenizer._pad_token is None:
420
+ raise ValueError(
421
+ "You are attempting to pad samples but the tokenizer you are using"
422
+ f" ({tokenizer.__class__.__name__}) does not have a pad token."
423
+ )
424
+
425
+ # Creating the full tensor and filling it with our data.
426
+ max_length = max(x.size(0) for x in examples)
427
+ if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
428
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
429
+ result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
430
+ for i, example in enumerate(examples):
431
+ if tokenizer.padding_side == "right":
432
+ result[i, : example.shape[0]] = example
433
+ else:
434
+ result[i, -example.shape[0] :] = example
435
+ return result
436
+
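+ # Worked example of the rounding above (values are arbitrary): with max_length = 13 and
+ # pad_to_multiple_of = 8, 13 % 8 != 0, so max_length becomes ((13 // 8) + 1) * 8 = 16.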
437
+
438
+ def _tf_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
439
+ """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
440
+ import tensorflow as tf
441
+
442
+ # Tensorize if necessary.
443
+ if isinstance(examples[0], (list, tuple)):
444
+ examples = [tf.convert_to_tensor(e, dtype=tf.int64) for e in examples]
445
+
446
+ # Check if padding is necessary.
447
+ length_of_first = len(examples[0])
448
+ are_tensors_same_length = all(len(x) == length_of_first for x in examples)
449
+ if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
450
+ return tf.stack(examples, axis=0)
451
+
452
+ # If yes, check if we have a `pad_token`.
453
+ if tokenizer._pad_token is None:
454
+ raise ValueError(
455
+ "You are attempting to pad samples but the tokenizer you are using"
456
+ f" ({tokenizer.__class__.__name__}) does not have a pad token."
457
+ )
458
+
459
+ # Creating the full tensor and filling it with our data.
460
+ max_length = max(len(x) for x in examples)
461
+ if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
462
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
463
+ # result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
464
+ result = []
465
+ rank = tf.rank(examples[0])
466
+ paddings = np.zeros((rank, 2), dtype=np.int32)
467
+ for example in examples:
468
+ if tokenizer.padding_side == "right":
469
+ paddings[0, 1] = max_length - len(example)
470
+ else:
471
+ paddings[0, 0] = max_length - len(example)
472
+ result.append(tf.pad(example, paddings, constant_values=tokenizer.pad_token_id))
473
+ return tf.stack(result, axis=0)
474
+
475
+
476
+ def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
477
+ """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
478
+ # Tensorize if necessary.
479
+ if isinstance(examples[0], (list, tuple)):
480
+ examples = [np.array(e, dtype=np.int64) for e in examples]
481
+
482
+ # Check if padding is necessary.
483
+ length_of_first = len(examples[0])
484
+ are_tensors_same_length = all(len(x) == length_of_first for x in examples)
485
+ if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
486
+ return np.stack(examples, axis=0)
487
+
488
+ # If yes, check if we have a `pad_token`.
489
+ if tokenizer._pad_token is None:
490
+ raise ValueError(
491
+ "You are attempting to pad samples but the tokenizer you are using"
492
+ f" ({tokenizer.__class__.__name__}) does not have a pad token."
493
+ )
494
+
495
+ # Creating the full tensor and filling it with our data.
496
+ max_length = max(len(x) for x in examples)
497
+ if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
498
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
499
+ result = np.full(shape=(len(examples), max_length), fill_value=tokenizer.pad_token_id, dtype=examples[0].dtype)
500
+ for i, example in enumerate(examples):
501
+ if tokenizer.padding_side == "right":
502
+ result[i, : example.shape[0]] = example
503
+ else:
504
+ result[i, -example.shape[0] :] = example
505
+ return result
506
+
507
+
508
+ def tolist(x):
509
+ if isinstance(x, list):
510
+ return x
511
+ elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import
512
+ x = x.numpy()
513
+ return x.tolist()
514
+
515
+
516
+ @dataclass
517
+ class DataCollatorForSeq2Seq:
518
+ """
519
+ Data collator that will dynamically pad the inputs received, as well as the labels.
520
+
521
+ Args:
522
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
523
+ The tokenizer used for encoding the data.
524
+ model ([`PreTrainedModel`], *optional*):
525
+ The model that is being trained. If set and the model has the *prepare_decoder_input_ids_from_labels* method, use it to
526
+ prepare the *decoder_input_ids*
527
+
528
+ This is useful when using *label_smoothing* to avoid calculating loss twice.
529
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
530
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
531
+ among:
532
+
533
+ - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
534
+ sequence is provided).
535
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
536
+ acceptable input length for the model if that argument is not provided.
537
+ - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
538
+ max_length (`int`, *optional*):
539
+ Maximum length of the returned list and optionally padding length (see above).
540
+ pad_to_multiple_of (`int`, *optional*):
541
+ If set will pad the sequence to a multiple of the provided value.
542
+
543
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
544
+ 7.5 (Volta).
545
+ label_pad_token_id (`int`, *optional*, defaults to -100):
546
+ The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
547
+ return_tensors (`str`, *optional*, defaults to `"pt"`):
548
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
549
+ """
550
+
551
+ tokenizer: PreTrainedTokenizerBase
552
+ model: Optional[Any] = None
553
+ padding: Union[bool, str, PaddingStrategy] = True
554
+ max_length: Optional[int] = None
555
+ pad_to_multiple_of: Optional[int] = None
556
+ label_pad_token_id: int = -100
557
+ return_tensors: str = "pt"
558
+
559
+ def __call__(self, features, return_tensors=None):
560
+ if return_tensors is None:
561
+ return_tensors = self.return_tensors
562
+ labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
563
+ # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
564
+ # same length to return tensors.
565
+ if labels is not None:
566
+ max_label_length = max(len(l) for l in labels)
567
+ if self.pad_to_multiple_of is not None:
568
+ max_label_length = (
569
+ (max_label_length + self.pad_to_multiple_of - 1)
570
+ // self.pad_to_multiple_of
571
+ * self.pad_to_multiple_of
572
+ )
573
+
574
+ padding_side = self.tokenizer.padding_side
575
+ for feature in features:
576
+ remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
577
+ if isinstance(feature["labels"], list):
578
+ feature["labels"] = (
579
+ feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
580
+ )
581
+ elif padding_side == "right":
582
+ feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
583
+ else:
584
+ feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
585
+
586
+ features = self.tokenizer.pad(
587
+ features,
588
+ padding=self.padding,
589
+ max_length=self.max_length,
590
+ pad_to_multiple_of=self.pad_to_multiple_of,
591
+ return_tensors=return_tensors,
592
+ )
593
+
594
+ # prepare decoder_input_ids
595
+ if (
596
+ labels is not None
597
+ and self.model is not None
598
+ and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
599
+ ):
600
+ decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])
601
+ features["decoder_input_ids"] = decoder_input_ids
602
+
603
+ return features
604
+
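+ # A minimal usage sketch of `DataCollatorForSeq2Seq`, assuming `tokenizer` and `model` are a
+ # seq2seq pair defined elsewhere; the feature contents are illustrative:
+ #
+ #     collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding="longest")
+ #     features = [
+ #         {"input_ids": [37, 423, 1], "labels": [99, 1]},
+ #         {"input_ids": [37, 1], "labels": [12, 87, 1]},
+ #     ]
+ #     batch = collator(features)
+ #     # "labels" are padded with -100 up to the longest label length, and if the model exposes
+ #     # prepare_decoder_input_ids_from_labels, a "decoder_input_ids" entry is added as well.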
605
+
606
+ @dataclass
607
+ class DataCollatorForLanguageModeling(DataCollatorMixin):
608
+ """
609
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
610
+ are not all of the same length.
611
+
612
+ Args:
613
+ tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
614
+ The tokenizer used for encoding the data.
615
+ mlm (`bool`, *optional*, defaults to `True`):
616
+ Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
617
+ with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
618
+ tokens and the value to predict for the masked token.
619
+ mlm_probability (`float`, *optional*, defaults to 0.15):
620
+ The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
621
+ pad_to_multiple_of (`int`, *optional*):
622
+ If set will pad the sequence to a multiple of the provided value.
623
+ return_tensors (`str`):
624
+ The type of Tensor to return. Allowable values are "np", "pt" and "tf".
625
+
626
+ <Tip>
627
+
628
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
629
+ BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
630
+ [`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
631
+
632
+ </Tip>"""
633
+
634
+ tokenizer: PreTrainedTokenizerBase
635
+ mlm: bool = True
636
+ mlm_probability: float = 0.15
637
+ pad_to_multiple_of: Optional[int] = None
638
+ tf_experimental_compile: bool = False
639
+ return_tensors: str = "pt"
640
+
641
+ def __post_init__(self):
642
+ if self.mlm and self.tokenizer.mask_token is None:
643
+ raise ValueError(
644
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
645
+ "You should pass `mlm=False` to train on causal language modeling instead."
646
+ )
647
+ if self.tf_experimental_compile:
648
+ import tensorflow as tf
649
+
650
+ self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)
651
+
652
+ @staticmethod
653
+ def tf_bernoulli(shape, probability):
654
+ import tensorflow as tf
655
+
656
+ prob_matrix = tf.fill(shape, probability)
657
+ return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
658
+
659
+ def tf_mask_tokens(
660
+ self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
661
+ ) -> Tuple[Any, Any]:
662
+ """
663
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
664
+ """
665
+ import tensorflow as tf
666
+
667
+ mask_token_id = tf.cast(mask_token_id, inputs.dtype)
668
+
669
+ input_shape = tf.shape(inputs)
670
+ # 1 for a special token, 0 for a normal token in the special tokens mask
671
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
672
+ masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
673
+ # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
674
+ labels = tf.where(masked_indices, inputs, -100)
675
+
676
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
677
+ indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
678
+
679
+ inputs = tf.where(indices_replaced, mask_token_id, inputs)
680
+
681
+ # 10% of the time, we replace masked input tokens with random word
682
+ indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced
683
+ random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
684
+
685
+ inputs = tf.where(indices_random, random_words, inputs)
686
+
687
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
688
+ return inputs, labels
689
+
690
+ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
691
+ import tensorflow as tf
692
+
693
+ # Handle dict or lists with proper padding and conversion to tensor.
694
+ if isinstance(examples[0], Mapping):
695
+ batch = self.tokenizer.pad(examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of)
696
+ else:
697
+ batch = {
698
+ "input_ids": _tf_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
699
+ }
700
+
701
+ # If special token mask has been preprocessed, pop it from the dict.
702
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
703
+ if self.mlm:
704
+ if special_tokens_mask is None:
705
+ special_tokens_mask = [
706
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
707
+ for val in batch["input_ids"].numpy().tolist()
708
+ ]
709
+ # Cannot directly create as bool
710
+ special_tokens_mask = tf.cast(tf.convert_to_tensor(special_tokens_mask, dtype=tf.int64), tf.bool)
711
+ else:
712
+ special_tokens_mask = tf.cast(special_tokens_mask, tf.bool)
713
+ batch["input_ids"], batch["labels"] = self.tf_mask_tokens(
714
+ tf.cast(batch["input_ids"], tf.int64),
715
+ special_tokens_mask=special_tokens_mask,
716
+ mask_token_id=self.tokenizer.mask_token_id,
717
+ vocab_size=len(self.tokenizer),
718
+ )
719
+ else:
720
+ labels = batch["input_ids"]
721
+ if self.tokenizer.pad_token_id is not None:
722
+ # Replace self.tokenizer.pad_token_id with -100
723
+ labels = tf.where(labels == self.tokenizer.pad_token_id, -100, labels)
724
+ else:
725
+ labels = tf.identity(labels) # Makes a copy, just in case
726
+ batch["labels"] = labels
727
+ return batch
728
+
729
+ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
730
+ # Handle dict or lists with proper padding and conversion to tensor.
731
+ if isinstance(examples[0], Mapping):
732
+ batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
733
+ else:
734
+ batch = {
735
+ "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
736
+ }
737
+
738
+ # If special token mask has been preprocessed, pop it from the dict.
739
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
740
+ if self.mlm:
741
+ batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
742
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
743
+ )
744
+ else:
745
+ labels = batch["input_ids"].clone()
746
+ if self.tokenizer.pad_token_id is not None:
747
+ labels[labels == self.tokenizer.pad_token_id] = -100
748
+ batch["labels"] = labels
749
+ return batch
750
+
751
+ def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
752
+ """
753
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
754
+ """
755
+ import torch
756
+
757
+ labels = inputs.clone()
758
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
759
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
760
+ if special_tokens_mask is None:
761
+ special_tokens_mask = [
762
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
763
+ ]
764
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
765
+ else:
766
+ special_tokens_mask = special_tokens_mask.bool()
767
+
768
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
769
+ masked_indices = torch.bernoulli(probability_matrix).bool()
770
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
771
+
772
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
773
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
774
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
775
+
776
+ # 10% of the time, we replace masked input tokens with random word
777
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
778
+ random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
779
+ inputs[indices_random] = random_words[indices_random]
780
+
781
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
782
+ return inputs, labels
783
+
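+ # Note on the 80/10/10 split above: after the masked positions are sampled, 80% of them are
+ # replaced by the mask token; of the remaining 20%, half (the 0.5 draw) become random words,
+ # i.e. 0.2 * 0.5 = 10% of all masked positions, and the final 10% are left unchanged.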
784
+ def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
785
+ # Handle dict or lists with proper padding and conversion to tensor.
786
+ if isinstance(examples[0], Mapping):
787
+ batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of)
788
+ else:
789
+ batch = {
790
+ "input_ids": _numpy_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
791
+ }
792
+
793
+ # If special token mask has been preprocessed, pop it from the dict.
794
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
795
+ if self.mlm:
796
+ batch["input_ids"], batch["labels"] = self.numpy_mask_tokens(
797
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
798
+ )
799
+ else:
800
+ labels = np.copy(batch["input_ids"])
801
+ if self.tokenizer.pad_token_id is not None:
802
+ labels[labels == self.tokenizer.pad_token_id] = -100
803
+ batch["labels"] = labels
804
+ return batch
805
+
806
+ def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
807
+ """
808
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
809
+ """
810
+ labels = np.copy(inputs)
811
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
812
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
813
+ if special_tokens_mask is None:
814
+ special_tokens_mask = [
815
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
816
+ ]
817
+ special_tokens_mask = np.array(special_tokens_mask, dtype=bool)
818
+ else:
819
+ special_tokens_mask = special_tokens_mask.astype(bool)
820
+
821
+ probability_matrix[special_tokens_mask] = 0
822
+ # Numpy doesn't have bernoulli, so we use a binomial with 1 trial
823
+ masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
824
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
825
+
826
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
827
+ indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
828
+ inputs[indices_replaced] = self.tokenizer.mask_token_id
829
+
830
+ # 10% of the time, we replace masked input tokens with random word
831
+ # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
832
+ indices_random = (
833
+ np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
834
+ )
835
+ random_words = np.random.randint(
836
+ low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
837
+ )
838
+ inputs[indices_random] = random_words
839
+
840
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
841
+ return inputs, labels
842
+
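+ # A minimal usage sketch of `DataCollatorForLanguageModeling`, assuming `tokenizer` and a list
+ # of raw `texts` defined elsewhere; tokenizing with `return_special_tokens_mask=True` follows
+ # the performance tip in the class docstring:
+ #
+ #     collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)
+ #     encodings = [tokenizer(t, return_special_tokens_mask=True) for t in texts]
+ #     batch = collator(encodings)
+ #     # batch["labels"] is -100 everywhere except at the positions selected for masking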
843
+
844
+ @dataclass
845
+ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
846
+ """
847
+ Data collator used for language modeling that masks entire words.
848
+
849
+ - collates batches of tensors, honoring their tokenizer's pad_token
850
+ - preprocesses batches for masked language modeling
851
+
852
+ <Tip>
853
+
854
+ This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
855
+ that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
856
+ produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].
857
+
858
+ </Tip>"""
859
+
860
+ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
861
+ if isinstance(examples[0], Mapping):
862
+ input_ids = [e["input_ids"] for e in examples]
863
+ else:
864
+ input_ids = examples
865
+ examples = [{"input_ids": e} for e in examples]
866
+
867
+ batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
868
+
869
+ mask_labels = []
870
+ for e in examples:
871
+ ref_tokens = []
872
+ for id in tolist(e["input_ids"]):
873
+ token = self.tokenizer._convert_id_to_token(id)
874
+ ref_tokens.append(token)
875
+
876
+ # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
877
+ if "chinese_ref" in e:
878
+ ref_pos = tolist(e["chinese_ref"])
879
+ len_seq = len(e["input_ids"])
880
+ for i in range(len_seq):
881
+ if i in ref_pos:
882
+ ref_tokens[i] = "##" + ref_tokens[i]
883
+ mask_labels.append(self._whole_word_mask(ref_tokens))
884
+ batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
885
+ inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
886
+ return {"input_ids": inputs, "labels": labels}
887
+
888
+ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
889
+ import tensorflow as tf
890
+
891
+ if isinstance(examples[0], Mapping):
892
+ input_ids = [e["input_ids"] for e in examples]
893
+ else:
894
+ input_ids = examples
895
+ examples = [{"input_ids": e} for e in examples]
896
+
897
+ batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
898
+
899
+ mask_labels = []
900
+ for e in examples:
901
+ ref_tokens = []
902
+ for id in tolist(e["input_ids"]):
903
+ token = self.tokenizer._convert_id_to_token(id)
904
+ ref_tokens.append(token)
905
+
906
+ # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
907
+ if "chinese_ref" in e:
908
+ ref_pos = tolist(e["chinese_ref"])
909
+ len_seq = len(e["input_ids"])
910
+ for i in range(len_seq):
911
+ if i in ref_pos:
912
+ ref_tokens[i] = "##" + ref_tokens[i]
913
+ mask_labels.append(self._whole_word_mask(ref_tokens))
914
+ batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
915
+ inputs, labels = self.tf_mask_tokens(tf.cast(batch_input, tf.int64), batch_mask)
916
+ return {"input_ids": inputs, "labels": labels}
917
+
918
+ def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
919
+ if isinstance(examples[0], Mapping):
920
+ input_ids = [e["input_ids"] for e in examples]
921
+ else:
922
+ input_ids = examples
923
+ examples = [{"input_ids": e} for e in examples]
924
+
925
+ batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
926
+
927
+ mask_labels = []
928
+ for e in examples:
929
+ ref_tokens = []
930
+ for id in tolist(e["input_ids"]):
931
+ token = self.tokenizer._convert_id_to_token(id)
932
+ ref_tokens.append(token)
933
+
934
+ # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
935
+ if "chinese_ref" in e:
936
+ ref_pos = tolist(e["chinese_ref"])
937
+ len_seq = len(e["input_ids"])
938
+ for i in range(len_seq):
939
+ if i in ref_pos:
940
+ ref_tokens[i] = "##" + ref_tokens[i]
941
+ mask_labels.append(self._whole_word_mask(ref_tokens))
942
+ batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
943
+ inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
944
+ return {"input_ids": inputs, "labels": labels}
945
+
946
+ def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
947
+ """
948
+ Get 0/1 labels for masked tokens with whole word mask proxy
949
+ """
950
+ if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
951
+ warnings.warn(
952
+ "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
953
+ "Please refer to the documentation for more information."
954
+ )
955
+
956
+ cand_indexes = []
957
+ for i, token in enumerate(input_tokens):
958
+ if token == "[CLS]" or token == "[SEP]":
959
+ continue
960
+
961
+ if len(cand_indexes) >= 1 and token.startswith("##"):
962
+ cand_indexes[-1].append(i)
963
+ else:
964
+ cand_indexes.append([i])
965
+
966
+ random.shuffle(cand_indexes)
967
+ num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
968
+ masked_lms = []
969
+ covered_indexes = set()
970
+ for index_set in cand_indexes:
971
+ if len(masked_lms) >= num_to_predict:
972
+ break
973
+ # If adding a whole-word mask would exceed the maximum number of
974
+ # predictions, then just skip this candidate.
975
+ if len(masked_lms) + len(index_set) > num_to_predict:
976
+ continue
977
+ is_any_index_covered = False
978
+ for index in index_set:
979
+ if index in covered_indexes:
980
+ is_any_index_covered = True
981
+ break
982
+ if is_any_index_covered:
983
+ continue
984
+ for index in index_set:
985
+ covered_indexes.add(index)
986
+ masked_lms.append(index)
987
+
988
+ if len(covered_indexes) != len(masked_lms):
989
+ raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
990
+ mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
991
+ return mask_labels
992
+
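+ # Example of the candidate grouping above (tokens are illustrative): for
+ # ["[CLS]", "hel", "##lo", "world", "[SEP]"], cand_indexes becomes [[1, 2], [3]], so the
+ # subword pieces of "hello" are always masked (or kept) together.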
993
+ def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
994
+ """
995
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
996
+ 'mask_labels' means we use whole word masking (wwm); we directly mask indices according to the provided mask reference.
997
+ """
998
+ import torch
999
+
1000
+ if self.tokenizer.mask_token is None:
1001
+ raise ValueError(
1002
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
1003
+ " --mlm flag if you want to use this tokenizer."
1004
+ )
1005
+ labels = inputs.clone()
1006
+ # We sample a few tokens in each sequence for masked-LM training (with probability `self.mlm_probability`, which defaults to 0.15 as in BERT/RoBERTa)
1007
+
1008
+ probability_matrix = mask_labels
1009
+
1010
+ special_tokens_mask = [
1011
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
1012
+ ]
1013
+ probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
1014
+ if self.tokenizer._pad_token is not None:
1015
+ padding_mask = labels.eq(self.tokenizer.pad_token_id)
1016
+ probability_matrix.masked_fill_(padding_mask, value=0.0)
1017
+
1018
+ masked_indices = probability_matrix.bool()
1019
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
1020
+
1021
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
1022
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
1023
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
1024
+
1025
+ # 10% of the time, we replace masked input tokens with random word
1026
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
1027
+ random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
1028
+ inputs[indices_random] = random_words[indices_random]
1029
+
1030
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
1031
+ return inputs, labels
1032
+
1033
+ def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
1034
+ """
1035
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
1036
+ 'mask_labels' means we use whole word masking (wwm); we directly mask indices according to the provided mask reference.
1037
+ """
1038
+ import tensorflow as tf
1039
+
1040
+ input_shape = tf.shape(inputs)
1041
+ if self.tokenizer.mask_token is None:
1042
+ raise ValueError(
1043
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
1044
+ " --mlm flag if you want to use this tokenizer."
1045
+ )
1046
+ labels = tf.identity(inputs)
1047
+ # We sample a few tokens in each sequence for masked-LM training (with probability `self.mlm_probability`, which defaults to 0.15 as in BERT/RoBERTa)
1048
+
1049
+ masked_indices = tf.cast(mask_labels, tf.bool)
1050
+
1051
+ special_tokens_mask = [
1052
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
1053
+ ]
1054
+ masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
1055
+ if self.tokenizer._pad_token is not None:
1056
+ padding_mask = inputs == self.tokenizer.pad_token_id
1057
+ masked_indices = masked_indices & ~padding_mask
1058
+
1059
+ # Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
1060
+ labels = tf.where(masked_indices, inputs, -100)
1061
+
1062
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
1063
+ indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
1064
+
1065
+ inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)
1066
+
1067
+ # 10% of the time, we replace masked input tokens with random word
1068
+ indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
1069
+ random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
1070
+ inputs = tf.where(indices_random, random_words, inputs)
1071
+
1072
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
1073
+ return inputs, labels
1074
+
1075
+ def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
1076
+ """
1077
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Setting
1078
+ 'mask_labels' means we use whole word masking (wwm); we directly mask indices according to the provided mask reference.
1079
+ """
1080
+ if self.tokenizer.mask_token is None:
1081
+ raise ValueError(
1082
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
1083
+ " --mlm flag if you want to use this tokenizer."
1084
+ )
1085
+ labels = np.copy(inputs)
1086
+ # We sample a few tokens in each sequence for masked-LM training (with probability `self.mlm_probability`, which defaults to 0.15 as in BERT/RoBERTa)
1087
+
1088
+ masked_indices = mask_labels.astype(bool)
1089
+
1090
+ special_tokens_mask = [
1091
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
1092
+ ]
1093
+ masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
1094
+ if self.tokenizer._pad_token is not None:
1095
+ padding_mask = labels == self.tokenizer.pad_token_id
1096
+ masked_indices[padding_mask] = 0
1097
+
1098
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
1099
+
1100
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
1101
+ indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
1102
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
1103
+
1104
+ # 10% of the time, we replace masked input tokens with random word
1105
+ # indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
1106
+ indices_random = (
1107
+ np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
1108
+ )
1109
+ random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
1110
+ inputs[indices_random] = random_words[indices_random]
1111
+
1112
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
1113
+ return inputs, labels
1114
+
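+ # A minimal usage sketch of `DataCollatorForWholeWordMask`, assuming a BERT-style `tokenizer`
+ # (subwords prefixed with "##") defined elsewhere; the inputs are illustrative:
+ #
+ #     collator = DataCollatorForWholeWordMask(tokenizer, mlm_probability=0.15)
+ #     batch = collator([tokenizer("a short example"), tokenizer("another example")])
+ #     # returns only {"input_ids", "labels"}; note that no attention_mask is produced here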
1115
+
1116
+ @dataclass
1117
+ class DataCollatorForSOP(DataCollatorForLanguageModeling):
1118
+ """
1119
+ Data collator used for sentence order prediction task.
1120
+
1121
+ - collates batches of tensors, honoring their tokenizer's pad_token
1122
+ - preprocesses batches for both masked language modeling and sentence order prediction
1123
+ """
1124
+
1125
+ def __init__(self, *args, **kwargs):
1126
+ warnings.warn(
1127
+ "DataCollatorForSOP is deprecated and will be removed in a future version, you can now use "
1128
+ "DataCollatorForLanguageModeling instead.",
1129
+ FutureWarning,
1130
+ )
1131
+
1132
+ def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
1133
+ import torch
1134
+ from torch.nn.utils.rnn import pad_sequence
1135
+
1136
+ input_ids = [example["input_ids"] for example in examples]
1137
+ input_ids = _torch_collate_batch(input_ids, self.tokenizer)
1138
+ input_ids, labels, attention_mask = self.mask_tokens(input_ids)
1139
+
1140
+ token_type_ids = [example["token_type_ids"] for example in examples]
1141
+ # The size of segment_ids varies because of randomness; pad at the end (with the tokenizer's pad token id), as in the original implementation
1142
+ token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
1143
+
1144
+ sop_label_list = [example["sentence_order_label"] for example in examples]
1145
+ sentence_order_label = torch.stack(sop_label_list)
1146
+
1147
+ return {
1148
+ "input_ids": input_ids,
1149
+ "labels": labels,
1150
+ "attention_mask": attention_mask,
1151
+ "token_type_ids": token_type_ids,
1152
+ "sentence_order_label": sentence_order_label,
1153
+ }
1154
+
1155
+ def mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any]:
1156
+ """
1157
+ Prepare masked tokens inputs/labels/attention_mask for masked language modeling: 80% MASK, 10% random, 10%
1158
+ original. N-gram not applied yet.
1159
+ """
1160
+ import torch
1161
+
1162
+ if self.tokenizer.mask_token is None:
1163
+ raise ValueError(
1164
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
1165
+ " --mlm flag if you want to use this tokenizer."
1166
+ )
1167
+
1168
+ labels = inputs.clone()
1169
+ # We sample a few tokens in each sequence for masked-LM training (with probability `self.mlm_probability`, which defaults to 0.15 as in BERT/RoBERTa)
1170
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
1171
+ special_tokens_mask = [
1172
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
1173
+ ]
1174
+ probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
1175
+ if self.tokenizer._pad_token is not None:
1176
+ padding_mask = labels.eq(self.tokenizer.pad_token_id)
1177
+ probability_matrix.masked_fill_(padding_mask, value=0.0)
1178
+ masked_indices = torch.bernoulli(probability_matrix).bool()
1179
+ # A probability of `1` means masked; however, in the ALBERT model the attention mask uses `0` for masked, so invert the value
1180
+ attention_mask = (~masked_indices).float()
1181
+ if self.tokenizer._pad_token is not None:
1182
+ attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
1183
+ attention_mask.masked_fill_(attention_padding_mask, value=1.0)
1184
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens, -100 is default for CE compute
1185
+
1186
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
1187
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
1188
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
1189
+
1190
+ # 10% of the time, we replace masked input tokens with random word
1191
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
1192
+ random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
1193
+ inputs[indices_random] = random_words[indices_random]
1194
+
1195
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
1196
+ return inputs, labels, attention_mask
1197
+
1198
+
1199
+ @dataclass
1200
+ class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
1201
+ """
1202
+ Data collator used for permutation language modeling.
1203
+
1204
+ - collates batches of tensors, honoring their tokenizer's pad_token
1205
+ - preprocesses batches for permutation language modeling with procedures specific to XLNet
1206
+ """
1207
+
1208
+ tokenizer: PreTrainedTokenizerBase
1209
+ plm_probability: float = 1 / 6
1210
+ max_span_length: int = 5 # maximum length of a span of masked tokens
1211
+ return_tensors: str = "pt"
1212
+
1213
+ def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
1214
+ if isinstance(examples[0], Mapping):
1215
+ examples = [e["input_ids"] for e in examples]
1216
+ batch = _torch_collate_batch(examples, self.tokenizer)
1217
+ inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
1218
+ return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
1219
+
1220
+ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
1221
+ if isinstance(examples[0], Mapping):
1222
+ examples = [e["input_ids"] for e in examples]
1223
+ batch = _tf_collate_batch(examples, self.tokenizer)
1224
+ inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
1225
+ return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
1226
+
1227
+ def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
1228
+ if isinstance(examples[0], Mapping):
1229
+ examples = [e["input_ids"] for e in examples]
1230
+ batch = _numpy_collate_batch(examples, self.tokenizer)
1231
+ inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
1232
+ return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
1233
+
1234
+ def torch_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
1235
+ """
1236
+ The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
1237
+
1238
+ 0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1239
+ 1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
1240
+ 2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
1241
+ masked
1242
+ 3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
1243
+ span_length]` and mask tokens `start_index:start_index + span_length`
1244
+ 4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
1245
+ sequence to be processed), repeat from Step 1.
1246
+ """
1247
+ import torch
1248
+
1249
+ if self.tokenizer.mask_token is None:
1250
+ raise ValueError(
1251
+ "This tokenizer does not have a mask token which is necessary for permutation language modeling."
1252
+ " Please add a mask token if you want to use this tokenizer."
1253
+ )
1254
+
1255
+ if inputs.size(1) % 2 != 0:
1256
+ raise ValueError(
1257
+ "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
1258
+ " relevant comments in source code for details."
1259
+ )
1260
+
1261
+ labels = inputs.clone()
1262
+ # Creating the mask and target_mapping tensors
1263
+ masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
1264
+ target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
1265
+
1266
+ for i in range(labels.size(0)):
1267
+ # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1268
+ cur_len = 0
1269
+ max_len = labels.size(1)
1270
+
1271
+ while cur_len < max_len:
1272
+ # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
1273
+ span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
1274
+ # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
1275
+ context_length = int(span_length / self.plm_probability)
1276
+ # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
1277
+ start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
1278
+ masked_indices[i, start_index : start_index + span_length] = 1
1279
+ # Set `cur_len = cur_len + context_length`
1280
+ cur_len += context_length
1281
+
1282
+ # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
1283
+ # the i-th prediction corresponds to the i-th token.
1284
+ target_mapping[i] = torch.eye(labels.size(1))
1285
+
1286
+ special_tokens_mask = torch.tensor(
1287
+ [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
1288
+ dtype=torch.bool,
1289
+ )
1290
+ masked_indices.masked_fill_(special_tokens_mask, value=0.0)
1291
+ if self.tokenizer._pad_token is not None:
1292
+ padding_mask = labels.eq(self.tokenizer.pad_token_id)
1293
+ masked_indices.masked_fill_(padding_mask, value=0.0)
1294
+
1295
+ # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
1296
+ non_func_mask = ~(padding_mask | special_tokens_mask)
1297
+
1298
+ inputs[masked_indices] = self.tokenizer.mask_token_id
1299
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
1300
+
1301
+ perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
1302
+
1303
+ for i in range(labels.size(0)):
1304
+ # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
1305
+ # determine which tokens a given token can attend to (encoded in `perm_mask`).
1306
+ # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
1307
+ # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
1308
+ # we assume that reused length is half of sequence length and permutation length is equal to reused length.
1309
+ # This requires that the sequence length be even.
1310
+
1311
+ # Create a linear factorisation order
1312
+ perm_index = torch.arange(labels.size(1))
1313
+ # Split this into two halves, assuming that half the sequence is reused each time
1314
+ perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
1315
+ # Permute the two halves such that they do not cross over
1316
+ perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
1317
+ # Flatten this out into the desired permuted factorisation order
1318
+ perm_index = torch.flatten(perm_index.transpose(0, 1))
1319
+ # Set the permutation indices of non-masked (non-functional) tokens to the
1320
+ # smallest index (-1) so that:
1321
+ # (1) They can be seen by all other positions
1322
+ # (2) They cannot see masked positions, so there won't be information leak
1323
+ perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
1324
+ # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
1325
+ # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
1326
+ # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
1327
+ perm_mask[i] = (
1328
+ perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
1329
+ ) & masked_indices[i]
1330
+
1331
+ return inputs.long(), perm_mask, target_mapping, labels.long()
1332
+
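+ # Worked example of the span sampling above (numbers are arbitrary): with plm_probability = 1/6
+ # and a sampled span_length = 3, context_length = int(3 / (1/6)) = 18, so one 3-token span is
+ # masked inside each 18-token context window before moving cur_len forward by 18.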
1333
+ def tf_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
1334
+ """
1335
+ The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
1336
+
1337
+ 0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1338
+ 1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
1339
+ 2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
1340
+ masked
1341
+ 3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
1342
+ span_length]` and mask tokens `start_index:start_index + span_length`
1343
+ 4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
1344
+ sequence to be processed), repeat from Step 1.
1345
+ """
1346
+ import tensorflow as tf
1347
+
1348
+ if self.tokenizer.mask_token is None:
1349
+ raise ValueError(
1350
+ "This tokenizer does not have a mask token which is necessary for permutation language modeling."
1351
+ " Please add a mask token if you want to use this tokenizer."
1352
+ )
1353
+
1354
+ if tf.shape(inputs)[1] % 2 != 0:
1355
+ raise ValueError(
1356
+ "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
1357
+ " relevant comments in source code for details."
1358
+ )
1359
+
1360
+ labels = tf.identity(inputs)
1361
+ # Creating the mask and target_mapping tensors
1362
+ masked_indices = np.full(labels.shape.as_list(), 0, dtype=bool)
1363
+ labels_shape = tf.shape(labels)
1364
+ target_mapping = np.zeros((labels_shape[0], labels_shape[1], labels_shape[1]), dtype=np.float32)
1365
+
1366
+ for i in range(len(labels)):
1367
+ # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1368
+ cur_len = 0
1369
+ max_len = tf.shape(labels)[1]
1370
+
1371
+ while cur_len < max_len:
1372
+ # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
1373
+ span_length = randint(1, self.max_span_length + 1)
1374
+ # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
1375
+ context_length = int(span_length / self.plm_probability)
1376
+ # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
1377
+ start_index = cur_len + randint(0, context_length - span_length + 1)
1378
+ masked_indices[i, start_index : start_index + span_length] = 1
1379
+ # Set `cur_len = cur_len + context_length`
1380
+ cur_len += context_length
1381
+
1382
+ # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
1383
+ # the i-th prediction corresponds to the i-th token.
1384
+ target_mapping[i] = np.eye(labels_shape[1])
1385
+ masked_indices = tf.cast(tf.convert_to_tensor(masked_indices), dtype=tf.bool)
1386
+ target_mapping = tf.convert_to_tensor(target_mapping)
1387
+ special_tokens_mask = tf.convert_to_tensor(
1388
+ [
1389
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
1390
+ for val in labels.numpy().tolist()
1391
+ ],
1392
+ )
1393
+ special_tokens_mask = tf.cast(special_tokens_mask, dtype=tf.bool)
1394
+ masked_indices = masked_indices & ~special_tokens_mask
1395
+ if self.tokenizer._pad_token is not None:
1396
+ padding_mask = labels == self.tokenizer.pad_token_id
1397
+ masked_indices = masked_indices & ~padding_mask
1398
+
1399
+ # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
1400
+ non_func_mask = ~(padding_mask | special_tokens_mask)
1401
+
1402
+ inputs = tf.where(masked_indices, self.tokenizer.mask_token_id, inputs)
1403
+ labels = tf.where(masked_indices, labels, -100) # We only compute loss on masked tokens
1404
+
1405
+ perm_mask = []
1406
+
1407
+ for i in range(len(labels)):
1408
+ # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
1409
+ # determine which tokens a given token can attend to (encoded in `perm_mask`).
1410
+ # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
1411
+ # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
1412
+ # we assume that reused length is half of sequence length and permutation length is equal to reused length.
1413
+ # This requires that the sequence length be even.
1414
+
1415
+ # Create a linear factorisation order
1416
+ # tf.range is the equivalent of torch.arange
1417
+ perm_index = tf.range(labels_shape[1])
1418
+ # Split this into two halves, assuming that half the sequence is reused each time
1419
+ perm_index = tf.transpose(tf.reshape(perm_index, (-1, labels_shape[1] // 2)))
1420
+ # Permute the two halves such that they do not cross over
1421
+ perm_index = tf.random.shuffle(perm_index) # Shuffles along the first dimension
1422
+ # Flatten this out into the desired permuted factorisation order
1423
+ perm_index = tf.reshape(tf.transpose(perm_index), (-1,))
1424
+ # Set the permutation indices of non-masked (non-functional) tokens to the
1425
+ # smallest index (-1) so that:
1426
+ # (1) They can be seen by all other positions
1427
+ # (2) They cannot see masked positions, so there won't be information leak
1428
+ perm_index = tf.where(~masked_indices[i] & non_func_mask[i], -1, perm_index)
1429
+ # The logic for whether the i-th token can attend to the j-th token, based on the factorisation order:
1430
+ # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
1431
+ # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
1432
+ perm_mask.append(
1433
+ (tf.reshape(perm_index, (labels_shape[1], 1)) <= tf.reshape(perm_index, (1, labels_shape[1])))
1434
+ & masked_indices[i]
1435
+ )
1436
+ perm_mask = tf.stack(perm_mask, axis=0)
1437
+
1438
+ return tf.cast(inputs, tf.int64), tf.cast(perm_mask, tf.float32), target_mapping, tf.cast(labels, tf.int64)
1439
+
1440
+ def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
1441
+ """
1442
+ The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
1443
+
1444
+ 0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1445
+ 1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
1446
+ 2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
1447
+ masked
1448
+ 3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
1449
+ span_length]` and mask tokens `start_index:start_index + span_length`
1450
+ 4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
1451
+ sequence to be processed), repeat from Step 1.
1452
+ """
1453
+ if self.tokenizer.mask_token is None:
1454
+ raise ValueError(
1455
+ "This tokenizer does not have a mask token which is necessary for permutation language modeling."
1456
+ " Please add a mask token if you want to use this tokenizer."
1457
+ )
1458
+
1459
+ if inputs.shape[1] % 2 != 0:
1460
+ raise ValueError(
1461
+ "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
1462
+ " relevant comments in source code for details."
1463
+ )
1464
+
1465
+ labels = np.copy(inputs)
1466
+ # Creating the mask and target_mapping tensors
1467
+ masked_indices = np.full(labels.shape, 0, dtype=bool)
1468
+ target_mapping = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
1469
+
1470
+ for i in range(labels.shape[0]):
1471
+ # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
1472
+ cur_len = 0
1473
+ max_len = labels.shape[1]
1474
+
1475
+ while cur_len < max_len:
1476
+ # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
1477
+ span_length = randint(1, self.max_span_length + 1)
1478
+ # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
1479
+ context_length = int(span_length / self.plm_probability)
1480
+ # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
1481
+ start_index = cur_len + randint(0, context_length - span_length + 1)
1482
+ masked_indices[i, start_index : start_index + span_length] = 1
1483
+ # Set `cur_len = cur_len + context_length`
1484
+ cur_len += context_length
1485
+
1486
+ # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
1487
+ # the i-th prediction corresponds to the i-th token.
1488
+ target_mapping[i] = np.eye(labels.shape[1])
1489
+
1490
+ special_tokens_mask = np.array(
1491
+ [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
1492
+ dtype=bool,
1493
+ )
1494
+ masked_indices[special_tokens_mask] = 0
1495
+ if self.tokenizer._pad_token is not None:
1496
+ padding_mask = labels == self.tokenizer.pad_token_id
1497
+ masked_indices[padding_mask] = 0.0
1498
+
1499
+ # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
1500
+ non_func_mask = ~(padding_mask | special_tokens_mask)
1501
+
1502
+ inputs[masked_indices] = self.tokenizer.mask_token_id
1503
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
1504
+
1505
+ perm_mask = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
1506
+
1507
+ for i in range(labels.shape[0]):
1508
+ # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
1509
+ # determine which tokens a given token can attend to (encoded in `perm_mask`).
1510
+ # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
1511
+ # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
1512
+ # we assume that reused length is half of sequence length and permutation length is equal to reused length.
1513
+ # This requires that the sequence length be even.
1514
+
1515
+ # Create a linear factorisation order
1516
+ perm_index = np.arange(labels.shape[1])
1517
+ # Split this into two halves, assuming that half the sequence is reused each time
1518
+ perm_index = perm_index.reshape((-1, labels.shape[1] // 2)).T
1519
+ # Permute the two halves such that they do not cross over
1520
+ np.random.shuffle(perm_index)
1521
+ # Flatten this out into the desired permuted factorisation order
1522
+ perm_index = perm_index.T.flatten()
1523
+ # Set the permutation indices of non-masked (non-functional) tokens to the
1524
+ # smallest index (-1) so that:
1525
+ # (1) They can be seen by all other positions
1526
+ # (2) They cannot see masked positions, so there won't be information leak
1527
+ perm_index[~masked_indices[i] & non_func_mask[i]] = -1
1528
+ # The logic for whether the i-th token can attend to the j-th token, based on the factorisation order:
1529
+ # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
1530
+ # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
1531
+ perm_mask[i] = (
1532
+ perm_index.reshape((labels.shape[1], 1)) <= perm_index.reshape((1, labels.shape[1]))
1533
+ ) & masked_indices[i]
1534
+
1535
+ return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)
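The three framework branches above implement the same span-masking and permutation scheme. A minimal NumPy sketch for a single toy sequence (the sequence length, masked span, and absence of special tokens are arbitrary choices for illustration, not part of the diff) makes the leakage-free `perm_mask` construction easier to follow:

import numpy as np

seq_len = 8                                     # must be even, as checked above
masked = np.zeros(seq_len, dtype=bool)
masked[2:4] = True                              # pretend tokens 2-3 were span-masked
non_func = np.ones(seq_len, dtype=bool)         # toy case: no special or padding tokens

# Linear order, split into two halves, shuffled without crossing, then flattened
perm_index = np.arange(seq_len).reshape((-1, seq_len // 2)).T
np.random.shuffle(perm_index)
perm_index = perm_index.T.flatten()
perm_index[~masked & non_func] = -1             # non-masked tokens are visible to every position

# perm_mask[i, j] == True -> token i cannot attend to token j
perm_mask = (perm_index.reshape((seq_len, 1)) <= perm_index.reshape((1, seq_len))) & masked
print(perm_mask.astype(int))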
transformers_4_35_0/data/datasets/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .glue import GlueDataset, GlueDataTrainingArguments
16
+ from .language_modeling import (
17
+ LineByLineTextDataset,
18
+ LineByLineWithRefDataset,
19
+ LineByLineWithSOPTextDataset,
20
+ TextDataset,
21
+ TextDatasetForNextSentencePrediction,
22
+ )
23
+ from .squad import SquadDataset, SquadDataTrainingArguments
transformers_4_35_0/data/datasets/glue.py ADDED
@@ -0,0 +1,161 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import time
17
+ import warnings
18
+ from dataclasses import dataclass, field
19
+ from enum import Enum
20
+ from typing import List, Optional, Union
21
+
22
+ import torch
23
+ from filelock import FileLock
24
+ from torch.utils.data import Dataset
25
+
26
+ from ...tokenization_utils_base import PreTrainedTokenizerBase
27
+ from ...utils import logging
28
+ from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
29
+ from ..processors.utils import InputFeatures
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ @dataclass
36
+ class GlueDataTrainingArguments:
37
+ """
38
+ Arguments pertaining to what data we are going to input our model for training and eval.
39
+
40
+ Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
41
+ line.
42
+ """
43
+
44
+ task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
45
+ data_dir: str = field(
46
+ metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
47
+ )
48
+ max_seq_length: int = field(
49
+ default=128,
50
+ metadata={
51
+ "help": (
52
+ "The maximum total input sequence length after tokenization. Sequences longer "
53
+ "than this will be truncated, sequences shorter will be padded."
54
+ )
55
+ },
56
+ )
57
+ overwrite_cache: bool = field(
58
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
59
+ )
60
+
61
+ def __post_init__(self):
62
+ self.task_name = self.task_name.lower()
63
+
64
+
65
+ class Split(Enum):
66
+ train = "train"
67
+ dev = "dev"
68
+ test = "test"
69
+
70
+
71
+ class GlueDataset(Dataset):
72
+ """
73
+ This will be superseded by a framework-agnostic approach soon.
74
+ """
75
+
76
+ args: GlueDataTrainingArguments
77
+ output_mode: str
78
+ features: List[InputFeatures]
79
+
80
+ def __init__(
81
+ self,
82
+ args: GlueDataTrainingArguments,
83
+ tokenizer: PreTrainedTokenizerBase,
84
+ limit_length: Optional[int] = None,
85
+ mode: Union[str, Split] = Split.train,
86
+ cache_dir: Optional[str] = None,
87
+ ):
88
+ warnings.warn(
89
+ "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
90
+ "library. You can have a look at this example script for pointers: "
91
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
92
+ FutureWarning,
93
+ )
94
+ self.args = args
95
+ self.processor = glue_processors[args.task_name]()
96
+ self.output_mode = glue_output_modes[args.task_name]
97
+ if isinstance(mode, str):
98
+ try:
99
+ mode = Split[mode]
100
+ except KeyError:
101
+ raise KeyError("mode is not a valid split name")
102
+ # Load data features from cache or dataset file
103
+ cached_features_file = os.path.join(
104
+ cache_dir if cache_dir is not None else args.data_dir,
105
+ f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{args.task_name}",
106
+ )
107
+ label_list = self.processor.get_labels()
108
+ if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__.__name__ in (
109
+ "RobertaTokenizer",
110
+ "RobertaTokenizerFast",
111
+ "XLMRobertaTokenizer",
112
+ "BartTokenizer",
113
+ "BartTokenizerFast",
114
+ ):
115
+ # HACK(label indices are swapped in RoBERTa pretrained model)
116
+ label_list[1], label_list[2] = label_list[2], label_list[1]
117
+ self.label_list = label_list
118
+
119
+ # Make sure only the first process in distributed training processes the dataset,
120
+ # and the others will use the cache.
121
+ lock_path = cached_features_file + ".lock"
122
+ with FileLock(lock_path):
123
+ if os.path.exists(cached_features_file) and not args.overwrite_cache:
124
+ start = time.time()
125
+ self.features = torch.load(cached_features_file)
126
+ logger.info(
127
+ f"Loading features from cached file {cached_features_file} [took {time.time() - start:.3f} s]"
128
+ )
129
+ else:
130
+ logger.info(f"Creating features from dataset file at {args.data_dir}")
131
+
132
+ if mode == Split.dev:
133
+ examples = self.processor.get_dev_examples(args.data_dir)
134
+ elif mode == Split.test:
135
+ examples = self.processor.get_test_examples(args.data_dir)
136
+ else:
137
+ examples = self.processor.get_train_examples(args.data_dir)
138
+ if limit_length is not None:
139
+ examples = examples[:limit_length]
140
+ self.features = glue_convert_examples_to_features(
141
+ examples,
142
+ tokenizer,
143
+ max_length=args.max_seq_length,
144
+ label_list=label_list,
145
+ output_mode=self.output_mode,
146
+ )
147
+ start = time.time()
148
+ torch.save(self.features, cached_features_file)
149
+ # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
150
+ logger.info(
151
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
152
+ )
153
+
154
+ def __len__(self):
155
+ return len(self.features)
156
+
157
+ def __getitem__(self, i) -> InputFeatures:
158
+ return self.features[i]
159
+
160
+ def get_labels(self):
161
+ return self.label_list
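A short usage sketch for the dataset above. The checkpoint name and data directory are placeholders, the upstream `transformers` package layout is assumed (in this vendored copy the top-level package is `transformers_4_35_0`), and the class itself is deprecated in favour of the 🤗 Datasets library:

from transformers import AutoTokenizer
from transformers.data.datasets import GlueDataset, GlueDataTrainingArguments

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
args = GlueDataTrainingArguments(task_name="mrpc", data_dir="./glue_data/MRPC", max_seq_length=128)
train_dataset = GlueDataset(args, tokenizer=tokenizer, mode="train")
print(len(train_dataset), train_dataset[0].label)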
transformers_4_35_0/data/datasets/language_modeling.py ADDED
@@ -0,0 +1,530 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ import pickle
18
+ import random
19
+ import time
20
+ import warnings
21
+ from typing import Dict, List, Optional
22
+
23
+ import torch
24
+ from filelock import FileLock
25
+ from torch.utils.data import Dataset
26
+
27
+ from ...tokenization_utils import PreTrainedTokenizer
28
+ from ...utils import logging
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ DEPRECATION_WARNING = (
35
+ "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
36
+ "library. You can have a look at this example script for pointers: {0}"
37
+ )
38
+
39
+
40
+ class TextDataset(Dataset):
41
+ """
42
+ This will be superseded by a framework-agnostic approach soon.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ tokenizer: PreTrainedTokenizer,
48
+ file_path: str,
49
+ block_size: int,
50
+ overwrite_cache=False,
51
+ cache_dir: Optional[str] = None,
52
+ ):
53
+ warnings.warn(
54
+ DEPRECATION_WARNING.format(
55
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
56
+ ),
57
+ FutureWarning,
58
+ )
59
+ if os.path.isfile(file_path) is False:
60
+ raise ValueError(f"Input file path {file_path} not found")
61
+
62
+ block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
63
+
64
+ directory, filename = os.path.split(file_path)
65
+ cached_features_file = os.path.join(
66
+ cache_dir if cache_dir is not None else directory,
67
+ f"cached_lm_{tokenizer.__class__.__name__}_{block_size}_{filename}",
68
+ )
69
+
70
+ # Make sure only the first process in distributed training processes the dataset,
71
+ # and the others will use the cache.
72
+ lock_path = cached_features_file + ".lock"
73
+ with FileLock(lock_path):
74
+ if os.path.exists(cached_features_file) and not overwrite_cache:
75
+ start = time.time()
76
+ with open(cached_features_file, "rb") as handle:
77
+ self.examples = pickle.load(handle)
78
+ logger.info(
79
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
80
+ )
81
+
82
+ else:
83
+ logger.info(f"Creating features from dataset file at {directory}")
84
+
85
+ self.examples = []
86
+ with open(file_path, encoding="utf-8") as f:
87
+ text = f.read()
88
+
89
+ tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
90
+
91
+ for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size
92
+ self.examples.append(
93
+ tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])
94
+ )
95
+ # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
96
+ # If your dataset is small, first you should look for a bigger one :-) and second you
97
+ # can change this behavior by adding (model specific) padding.
98
+
99
+ start = time.time()
100
+ with open(cached_features_file, "wb") as handle:
101
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
102
+ logger.info(
103
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
104
+ )
105
+
106
+ def __len__(self):
107
+ return len(self.examples)
108
+
109
+ def __getitem__(self, i) -> torch.Tensor:
110
+ return torch.tensor(self.examples[i], dtype=torch.long)
111
+
112
+
113
+ class LineByLineTextDataset(Dataset):
114
+ """
115
+ This will be superseded by a framework-agnostic approach soon.
116
+ """
117
+
118
+ def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
119
+ warnings.warn(
120
+ DEPRECATION_WARNING.format(
121
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
122
+ ),
123
+ FutureWarning,
124
+ )
125
+ if os.path.isfile(file_path) is False:
126
+ raise ValueError(f"Input file path {file_path} not found")
127
+ # Here, we do not cache the features, operating under the assumption
128
+ # that we will soon use fast multithreaded tokenizers from the
129
+ # `tokenizers` repo everywhere =)
130
+ logger.info(f"Creating features from dataset file at {file_path}")
131
+
132
+ with open(file_path, encoding="utf-8") as f:
133
+ lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
134
+
135
+ batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)
136
+ self.examples = batch_encoding["input_ids"]
137
+ self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
138
+
139
+ def __len__(self):
140
+ return len(self.examples)
141
+
142
+ def __getitem__(self, i) -> Dict[str, torch.tensor]:
143
+ return self.examples[i]
144
+
145
+
146
+ class LineByLineWithRefDataset(Dataset):
147
+ """
148
+ This will be superseded by a framework-agnostic approach soon.
149
+ """
150
+
151
+ def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
152
+ warnings.warn(
153
+ DEPRECATION_WARNING.format(
154
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm_wwm.py"
155
+ ),
156
+ FutureWarning,
157
+ )
158
+ if os.path.isfile(file_path) is False:
159
+ raise ValueError(f"Input file path {file_path} not found")
160
+ if os.path.isfile(ref_path) is False:
161
+ raise ValueError(f"Ref file path {ref_path} not found")
162
+ # Here, we do not cache the features, operating under the assumption
163
+ # that we will soon use fast multithreaded tokenizers from the
164
+ # `tokenizers` repo everywhere =)
165
+ logger.info(f"Creating features from dataset file at {file_path}")
166
+ logger.info(f"Use ref segment results at {ref_path}")
167
+ with open(file_path, encoding="utf-8") as f:
168
+ data = f.readlines() # use this method to avoid delimiter '\u2029' to split a line
169
+ data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
170
+ # Get ref inf from file
171
+ with open(ref_path, encoding="utf-8") as f:
172
+ ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
173
+ if len(data) != len(ref):
174
+ raise ValueError(
175
+ f"Length of Input file should be equal to Ref file. But the length of {file_path} is {len(data)} "
176
+ f"while length of {ref_path} is {len(ref)}"
177
+ )
178
+
179
+ batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
180
+ self.examples = batch_encoding["input_ids"]
181
+ self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
182
+
183
+ n = len(self.examples)
184
+ for i in range(n):
185
+ self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
186
+
187
+ def __len__(self):
188
+ return len(self.examples)
189
+
190
+ def __getitem__(self, i) -> Dict[str, torch.tensor]:
191
+ return self.examples[i]
192
+
193
+
194
+ class LineByLineWithSOPTextDataset(Dataset):
195
+ """
196
+ Dataset for sentence order prediction task, prepare sentence pairs for SOP task
197
+ """
198
+
199
+ def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
200
+ warnings.warn(
201
+ DEPRECATION_WARNING.format(
202
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
203
+ ),
204
+ FutureWarning,
205
+ )
206
+ if os.path.isdir(file_dir) is False:
207
+ raise ValueError(f"{file_dir} is not a directory")
208
+ logger.info(f"Creating features from dataset file folder at {file_dir}")
209
+ self.examples = []
210
+ # TODO: randomness could apply a random seed, ex. rng = random.Random(random_seed)
211
+ # file path looks like ./dataset/wiki_1, ./dataset/wiki_2
212
+ for file_name in os.listdir(file_dir):
213
+ file_path = os.path.join(file_dir, file_name)
214
+ if os.path.isfile(file_path) is False:
215
+ raise ValueError(f"{file_path} is not a file")
216
+ article_open = False
217
+ with open(file_path, encoding="utf-8") as f:
218
+ original_lines = f.readlines()
219
+ article_lines = []
220
+ for line in original_lines:
221
+ if "<doc id=" in line:
222
+ article_open = True
223
+ elif "</doc>" in line:
224
+ article_open = False
225
+ document = [
226
+ tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
227
+ for line in article_lines[1:]
228
+ if (len(line) > 0 and not line.isspace())
229
+ ]
230
+
231
+ examples = self.create_examples_from_document(document, block_size, tokenizer)
232
+ self.examples.extend(examples)
233
+ article_lines = []
234
+ else:
235
+ if article_open:
236
+ article_lines.append(line)
237
+
238
+ logger.info("Dataset parse finished.")
239
+
240
+ def create_examples_from_document(self, document, block_size, tokenizer, short_seq_prob=0.1):
241
+ """Creates examples for a single document."""
242
+
243
+ # Account for special tokens
244
+ max_num_tokens = block_size - tokenizer.num_special_tokens_to_add(pair=True)
245
+
246
+ # We *usually* want to fill up the entire sequence since we are padding
247
+ # to `block_size` anyways, so short sequences are generally wasted
248
+ # computation. However, we *sometimes*
249
+ # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
250
+ # sequences to minimize the mismatch between pretraining and fine-tuning.
251
+ # The `target_seq_length` is just a rough target however, whereas
252
+ # `block_size` is a hard limit.
253
+ target_seq_length = max_num_tokens
254
+ if random.random() < short_seq_prob:
255
+ target_seq_length = random.randint(2, max_num_tokens)
256
+
257
+ # We DON'T just concatenate all of the tokens from a document into a long
258
+ # sequence and choose an arbitrary split point because this would make the
259
+ # next sentence prediction task too easy. Instead, we split the input into
260
+ # segments "A" and "B" based on the actual "sentences" provided by the user
261
+ # input.
262
+ examples = []
263
+ current_chunk = [] # a buffer storing the current working segments
264
+ current_length = 0
265
+ i = 0
266
+ while i < len(document):
267
+ segment = document[i] # get a segment
268
+ if not segment:
269
+ i += 1
270
+ continue
271
+ current_chunk.append(segment) # add a segment to current chunk
272
+ current_length += len(segment) # overall token length
273
+ # if the current length reaches the target length or the end of the document, start building tokens A and B
274
+ if i == len(document) - 1 or current_length >= target_seq_length:
275
+ if current_chunk:
276
+ # `a_end` is how many segments from `current_chunk` go into the `A` (first) sentence.
277
+ a_end = 1
278
+ # if the current chunk has at least 2 segments, randomly choose how many of them go into the `A` (first) sentence
279
+ if len(current_chunk) >= 2:
280
+ a_end = random.randint(1, len(current_chunk) - 1)
281
+ # token a
282
+ tokens_a = []
283
+ for j in range(a_end):
284
+ tokens_a.extend(current_chunk[j])
285
+
286
+ # token b
287
+ tokens_b = []
288
+ for j in range(a_end, len(current_chunk)):
289
+ tokens_b.extend(current_chunk[j])
290
+
291
+ if len(tokens_a) == 0 or len(tokens_b) == 0:
292
+ continue
293
+
294
+ # switch tokens_a and tokens_b randomly
295
+ if random.random() < 0.5:
296
+ is_next = False
297
+ tokens_a, tokens_b = tokens_b, tokens_a
298
+ else:
299
+ is_next = True
300
+
301
+ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
302
+ """Truncates a pair of sequences to a maximum sequence length."""
303
+ while True:
304
+ total_length = len(tokens_a) + len(tokens_b)
305
+ if total_length <= max_num_tokens:
306
+ break
307
+ trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
308
+ if not (len(trunc_tokens) >= 1):
309
+ raise ValueError("Sequence length to be truncated must be no less than one")
310
+ # We want to sometimes truncate from the front and sometimes from the
311
+ # back to add more randomness and avoid biases.
312
+ if random.random() < 0.5:
313
+ del trunc_tokens[0]
314
+ else:
315
+ trunc_tokens.pop()
316
+
317
+ truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
318
+ if not (len(tokens_a) >= 1):
319
+ raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
320
+ if not (len(tokens_b) >= 1):
321
+ raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
322
+
323
+ # add special tokens
324
+ input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
325
+ # add token type ids, 0 for sentence a, 1 for sentence b
326
+ token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
327
+
328
+ example = {
329
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
330
+ "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
331
+ "sentence_order_label": torch.tensor(0 if is_next else 1, dtype=torch.long),
332
+ }
333
+ examples.append(example)
334
+ current_chunk = [] # clear current chunk
335
+ current_length = 0 # reset current text length
336
+ i += 1 # go to next line
337
+ return examples
338
+
339
+ def __len__(self):
340
+ return len(self.examples)
341
+
342
+ def __getitem__(self, i) -> Dict[str, torch.tensor]:
343
+ return self.examples[i]
344
+
345
+
346
+ class TextDatasetForNextSentencePrediction(Dataset):
347
+ """
348
+ This will be superseded by a framework-agnostic approach soon.
349
+ """
350
+
351
+ def __init__(
352
+ self,
353
+ tokenizer: PreTrainedTokenizer,
354
+ file_path: str,
355
+ block_size: int,
356
+ overwrite_cache=False,
357
+ short_seq_probability=0.1,
358
+ nsp_probability=0.5,
359
+ ):
360
+ warnings.warn(
361
+ DEPRECATION_WARNING.format(
362
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
363
+ ),
364
+ FutureWarning,
365
+ )
366
+ if not os.path.isfile(file_path):
367
+ raise ValueError(f"Input file path {file_path} not found")
368
+
369
+ self.short_seq_probability = short_seq_probability
370
+ self.nsp_probability = nsp_probability
371
+
372
+ directory, filename = os.path.split(file_path)
373
+ cached_features_file = os.path.join(
374
+ directory,
375
+ f"cached_nsp_{tokenizer.__class__.__name__}_{block_size}_{filename}",
376
+ )
377
+
378
+ self.tokenizer = tokenizer
379
+
380
+ # Make sure only the first process in distributed training processes the dataset,
381
+ # and the others will use the cache.
382
+ lock_path = cached_features_file + ".lock"
383
+
384
+ # Input file format:
385
+ # (1) One sentence per line. These should ideally be actual sentences, not
386
+ # entire paragraphs or arbitrary spans of text. (Because we use the
387
+ # sentence boundaries for the "next sentence prediction" task).
388
+ # (2) Blank lines between documents. Document boundaries are needed so
389
+ # that the "next sentence prediction" task doesn't span between documents.
390
+ #
391
+ # Example:
392
+ # I am very happy.
393
+ # Here is the second sentence.
394
+ #
395
+ # A new document.
396
+
397
+ with FileLock(lock_path):
398
+ if os.path.exists(cached_features_file) and not overwrite_cache:
399
+ start = time.time()
400
+ with open(cached_features_file, "rb") as handle:
401
+ self.examples = pickle.load(handle)
402
+ logger.info(
403
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
404
+ )
405
+ else:
406
+ logger.info(f"Creating features from dataset file at {directory}")
407
+
408
+ self.documents = [[]]
409
+ with open(file_path, encoding="utf-8") as f:
410
+ while True:
411
+ line = f.readline()
412
+ if not line:
413
+ break
414
+ line = line.strip()
415
+
416
+ # Empty lines are used as document delimiters
417
+ if not line and len(self.documents[-1]) != 0:
418
+ self.documents.append([])
419
+ tokens = tokenizer.tokenize(line)
420
+ tokens = tokenizer.convert_tokens_to_ids(tokens)
421
+ if tokens:
422
+ self.documents[-1].append(tokens)
423
+
424
+ logger.info(f"Creating examples from {len(self.documents)} documents.")
425
+ self.examples = []
426
+ for doc_index, document in enumerate(self.documents):
427
+ self.create_examples_from_document(document, doc_index, block_size)
428
+
429
+ start = time.time()
430
+ with open(cached_features_file, "wb") as handle:
431
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
432
+ logger.info(
433
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
434
+ )
435
+
436
+ def create_examples_from_document(self, document: List[List[int]], doc_index: int, block_size: int):
437
+ """Creates examples for a single document."""
438
+
439
+ max_num_tokens = block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
440
+
441
+ # We *usually* want to fill up the entire sequence since we are padding
442
+ # to `block_size` anyways, so short sequences are generally wasted
443
+ # computation. However, we *sometimes*
444
+ # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
445
+ # sequences to minimize the mismatch between pretraining and fine-tuning.
446
+ # The `target_seq_length` is just a rough target however, whereas
447
+ # `block_size` is a hard limit.
448
+ target_seq_length = max_num_tokens
449
+ if random.random() < self.short_seq_probability:
450
+ target_seq_length = random.randint(2, max_num_tokens)
451
+
452
+ current_chunk = [] # a buffer storing the current working segments
453
+ current_length = 0
454
+ i = 0
455
+
456
+ while i < len(document):
457
+ segment = document[i]
458
+ current_chunk.append(segment)
459
+ current_length += len(segment)
460
+ if i == len(document) - 1 or current_length >= target_seq_length:
461
+ if current_chunk:
462
+ # `a_end` is how many segments from `current_chunk` go into the `A`
463
+ # (first) sentence.
464
+ a_end = 1
465
+ if len(current_chunk) >= 2:
466
+ a_end = random.randint(1, len(current_chunk) - 1)
467
+
468
+ tokens_a = []
469
+ for j in range(a_end):
470
+ tokens_a.extend(current_chunk[j])
471
+
472
+ tokens_b = []
473
+
474
+ if len(current_chunk) == 1 or random.random() < self.nsp_probability:
475
+ is_random_next = True
476
+ target_b_length = target_seq_length - len(tokens_a)
477
+
478
+ # This should rarely go for more than one iteration for large
479
+ # corpora. However, just to be careful, we try to make sure that
480
+ # the random document is not the same as the document
481
+ # we're processing.
482
+ for _ in range(10):
483
+ random_document_index = random.randint(0, len(self.documents) - 1)
484
+ if random_document_index != doc_index:
485
+ break
486
+
487
+ random_document = self.documents[random_document_index]
488
+ random_start = random.randint(0, len(random_document) - 1)
489
+ for j in range(random_start, len(random_document)):
490
+ tokens_b.extend(random_document[j])
491
+ if len(tokens_b) >= target_b_length:
492
+ break
493
+ # We didn't actually use these segments so we "put them back" so
494
+ # they don't go to waste.
495
+ num_unused_segments = len(current_chunk) - a_end
496
+ i -= num_unused_segments
497
+ # Actual next
498
+ else:
499
+ is_random_next = False
500
+ for j in range(a_end, len(current_chunk)):
501
+ tokens_b.extend(current_chunk[j])
502
+
503
+ if not (len(tokens_a) >= 1):
504
+ raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
505
+ if not (len(tokens_b) >= 1):
506
+ raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
507
+
508
+ # add special tokens
509
+ input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
510
+ # add token type ids, 0 for sentence a, 1 for sentence b
511
+ token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
512
+
513
+ example = {
514
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
515
+ "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
516
+ "next_sentence_label": torch.tensor(1 if is_random_next else 0, dtype=torch.long),
517
+ }
518
+
519
+ self.examples.append(example)
520
+
521
+ current_chunk = []
522
+ current_length = 0
523
+
524
+ i += 1
525
+
526
+ def __len__(self):
527
+ return len(self.examples)
528
+
529
+ def __getitem__(self, i):
530
+ return self.examples[i]
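A short usage sketch pairing one of the (deprecated) datasets above with a masked-LM data collator. The checkpoint name is a placeholder, "train.txt" is a hypothetical file with one sentence per line and at least four non-empty lines, and the upstream `transformers` package layout is assumed:

from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from transformers.data.datasets import LineByLineTextDataset

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = LineByLineTextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=128)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
batch = collator([dataset[i] for i in range(4)])  # pads the examples and applies random masking
print(batch["input_ids"].shape, batch["labels"].shape)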
transformers_4_35_0/data/datasets/squad.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import time
17
+ from dataclasses import dataclass, field
18
+ from enum import Enum
19
+ from typing import Dict, List, Optional, Union
20
+
21
+ import torch
22
+ from filelock import FileLock
23
+ from torch.utils.data import Dataset
24
+
25
+ from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
26
+ from ...tokenization_utils import PreTrainedTokenizer
27
+ from ...utils import logging
28
+ from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
34
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
35
+
36
+
37
+ @dataclass
38
+ class SquadDataTrainingArguments:
39
+ """
40
+ Arguments pertaining to what data we are going to input our model for training and eval.
41
+ """
42
+
43
+ model_type: str = field(
44
+ default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}
45
+ )
46
+ data_dir: str = field(
47
+ default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
48
+ )
49
+ max_seq_length: int = field(
50
+ default=128,
51
+ metadata={
52
+ "help": (
53
+ "The maximum total input sequence length after tokenization. Sequences longer "
54
+ "than this will be truncated, sequences shorter will be padded."
55
+ )
56
+ },
57
+ )
58
+ doc_stride: int = field(
59
+ default=128,
60
+ metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
61
+ )
62
+ max_query_length: int = field(
63
+ default=64,
64
+ metadata={
65
+ "help": (
66
+ "The maximum number of tokens for the question. Questions longer than this will "
67
+ "be truncated to this length."
68
+ )
69
+ },
70
+ )
71
+ max_answer_length: int = field(
72
+ default=30,
73
+ metadata={
74
+ "help": (
75
+ "The maximum length of an answer that can be generated. This is needed because the start "
76
+ "and end predictions are not conditioned on one another."
77
+ )
78
+ },
79
+ )
80
+ overwrite_cache: bool = field(
81
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
82
+ )
83
+ version_2_with_negative: bool = field(
84
+ default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
85
+ )
86
+ null_score_diff_threshold: float = field(
87
+ default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
88
+ )
89
+ n_best_size: int = field(
90
+ default=20, metadata={"help": "The total number of n-best predictions to generate in the nbest_predictions.json output file."}
91
+ )
92
+ lang_id: int = field(
93
+ default=0,
94
+ metadata={
95
+ "help": (
96
+ "language id of input for language-specific xlm models (see"
97
+ " tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
98
+ )
99
+ },
100
+ )
101
+ threads: int = field(default=1, metadata={"help": "Number of threads to use when converting the examples to features."})
102
+
103
+
104
+ class Split(Enum):
105
+ train = "train"
106
+ dev = "dev"
107
+
108
+
109
+ class SquadDataset(Dataset):
110
+ """
111
+ This will be superseded by a framework-agnostic approach soon.
112
+ """
113
+
114
+ args: SquadDataTrainingArguments
115
+ features: List[SquadFeatures]
116
+ mode: Split
117
+ is_language_sensitive: bool
118
+
119
+ def __init__(
120
+ self,
121
+ args: SquadDataTrainingArguments,
122
+ tokenizer: PreTrainedTokenizer,
123
+ limit_length: Optional[int] = None,
124
+ mode: Union[str, Split] = Split.train,
125
+ is_language_sensitive: Optional[bool] = False,
126
+ cache_dir: Optional[str] = None,
127
+ dataset_format: Optional[str] = "pt",
128
+ ):
129
+ self.args = args
130
+ self.is_language_sensitive = is_language_sensitive
131
+ self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
132
+ if isinstance(mode, str):
133
+ try:
134
+ mode = Split[mode]
135
+ except KeyError:
136
+ raise KeyError("mode is not a valid split name")
137
+ self.mode = mode
138
+ # Load data features from cache or dataset file
139
+ version_tag = "v2" if args.version_2_with_negative else "v1"
140
+ cached_features_file = os.path.join(
141
+ cache_dir if cache_dir is not None else args.data_dir,
142
+ f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{version_tag}",
143
+ )
144
+
145
+ # Make sure only the first process in distributed training processes the dataset,
146
+ # and the others will use the cache.
147
+ lock_path = cached_features_file + ".lock"
148
+ with FileLock(lock_path):
149
+ if os.path.exists(cached_features_file) and not args.overwrite_cache:
150
+ start = time.time()
151
+ self.old_features = torch.load(cached_features_file)
152
+
153
+ # Legacy cache files have only features, while new cache files
154
+ # will have dataset and examples also.
155
+ self.features = self.old_features["features"]
156
+ self.dataset = self.old_features.get("dataset", None)
157
+ self.examples = self.old_features.get("examples", None)
158
+ logger.info(
159
+ f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
160
+ )
161
+
162
+ if self.dataset is None or self.examples is None:
163
+ logger.warning(
164
+ f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in"
165
+ " a future run"
166
+ )
167
+ else:
168
+ if mode == Split.dev:
169
+ self.examples = self.processor.get_dev_examples(args.data_dir)
170
+ else:
171
+ self.examples = self.processor.get_train_examples(args.data_dir)
172
+
173
+ self.features, self.dataset = squad_convert_examples_to_features(
174
+ examples=self.examples,
175
+ tokenizer=tokenizer,
176
+ max_seq_length=args.max_seq_length,
177
+ doc_stride=args.doc_stride,
178
+ max_query_length=args.max_query_length,
179
+ is_training=mode == Split.train,
180
+ threads=args.threads,
181
+ return_dataset=dataset_format,
182
+ )
183
+
184
+ start = time.time()
185
+ torch.save(
186
+ {"features": self.features, "dataset": self.dataset, "examples": self.examples},
187
+ cached_features_file,
188
+ )
189
+ # ^ This seems to take a lot of time so I want to investigate why and how we can improve.
190
+ logger.info(
191
+ f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
192
+ )
193
+
194
+ def __len__(self):
195
+ return len(self.features)
196
+
197
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
198
+ # Convert to Tensors and build dataset
199
+ feature = self.features[i]
200
+
201
+ input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
202
+ attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long)
203
+ token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long)
204
+ cls_index = torch.tensor(feature.cls_index, dtype=torch.long)
205
+ p_mask = torch.tensor(feature.p_mask, dtype=torch.float)
206
+ is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float)
207
+
208
+ inputs = {
209
+ "input_ids": input_ids,
210
+ "attention_mask": attention_mask,
211
+ "token_type_ids": token_type_ids,
212
+ }
213
+
214
+ if self.args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
215
+ del inputs["token_type_ids"]
216
+
217
+ if self.args.model_type in ["xlnet", "xlm"]:
218
+ inputs.update({"cls_index": cls_index, "p_mask": p_mask})
219
+ if self.args.version_2_with_negative:
220
+ inputs.update({"is_impossible": is_impossible})
221
+ if self.is_language_sensitive:
222
+ inputs.update({"langs": (torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id)})
223
+
224
+ if self.mode == Split.train:
225
+ start_positions = torch.tensor(feature.start_position, dtype=torch.long)
226
+ end_positions = torch.tensor(feature.end_position, dtype=torch.long)
227
+ inputs.update({"start_positions": start_positions, "end_positions": end_positions})
228
+
229
+ return inputs
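A short usage sketch for the dataset above. The data directory is a placeholder that must contain the SQuAD train/dev JSON files, and the upstream `transformers` package layout is assumed:

from transformers import AutoTokenizer
from transformers.data.datasets import SquadDataset, SquadDataTrainingArguments

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
args = SquadDataTrainingArguments(model_type="bert", data_dir="./squad", threads=4)
train_dataset = SquadDataset(args, tokenizer=tokenizer, mode="train")
example = train_dataset[0]  # dict of tensors; includes answer positions in train mode
print(example["input_ids"].shape, example["start_positions"], example["end_positions"])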
transformers_4_35_0/data/metrics/__init__.py ADDED
@@ -0,0 +1,98 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at
4
+ #
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ #
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ import warnings
14
+
15
+ from ...utils import is_sklearn_available, requires_backends
16
+
17
+
18
+ if is_sklearn_available():
19
+ from scipy.stats import pearsonr, spearmanr
20
+ from sklearn.metrics import f1_score, matthews_corrcoef
21
+
22
+
23
+ DEPRECATION_WARNING = (
24
+ "This metric will be removed from the library soon, metrics should be handled with the 🤗 Evaluate "
25
+ "library. You can have a look at this example script for pointers: "
26
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
27
+ )
28
+
29
+
30
+ def simple_accuracy(preds, labels):
31
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
32
+ requires_backends(simple_accuracy, "sklearn")
33
+ return (preds == labels).mean()
34
+
35
+
36
+ def acc_and_f1(preds, labels):
37
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
38
+ requires_backends(acc_and_f1, "sklearn")
39
+ acc = simple_accuracy(preds, labels)
40
+ f1 = f1_score(y_true=labels, y_pred=preds)
41
+ return {
42
+ "acc": acc,
43
+ "f1": f1,
44
+ "acc_and_f1": (acc + f1) / 2,
45
+ }
46
+
47
+
48
+ def pearson_and_spearman(preds, labels):
49
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
50
+ requires_backends(pearson_and_spearman, "sklearn")
51
+ pearson_corr = pearsonr(preds, labels)[0]
52
+ spearman_corr = spearmanr(preds, labels)[0]
53
+ return {
54
+ "pearson": pearson_corr,
55
+ "spearmanr": spearman_corr,
56
+ "corr": (pearson_corr + spearman_corr) / 2,
57
+ }
58
+
59
+
60
+ def glue_compute_metrics(task_name, preds, labels):
61
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
62
+ requires_backends(glue_compute_metrics, "sklearn")
63
+ assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
64
+ if task_name == "cola":
65
+ return {"mcc": matthews_corrcoef(labels, preds)}
66
+ elif task_name == "sst-2":
67
+ return {"acc": simple_accuracy(preds, labels)}
68
+ elif task_name == "mrpc":
69
+ return acc_and_f1(preds, labels)
70
+ elif task_name == "sts-b":
71
+ return pearson_and_spearman(preds, labels)
72
+ elif task_name == "qqp":
73
+ return acc_and_f1(preds, labels)
74
+ elif task_name == "mnli":
75
+ return {"mnli/acc": simple_accuracy(preds, labels)}
76
+ elif task_name == "mnli-mm":
77
+ return {"mnli-mm/acc": simple_accuracy(preds, labels)}
78
+ elif task_name == "qnli":
79
+ return {"acc": simple_accuracy(preds, labels)}
80
+ elif task_name == "rte":
81
+ return {"acc": simple_accuracy(preds, labels)}
82
+ elif task_name == "wnli":
83
+ return {"acc": simple_accuracy(preds, labels)}
84
+ elif task_name == "hans":
85
+ return {"acc": simple_accuracy(preds, labels)}
86
+ else:
87
+ raise KeyError(task_name)
88
+
89
+
90
+ def xnli_compute_metrics(task_name, preds, labels):
91
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
92
+ requires_backends(xnli_compute_metrics, "sklearn")
93
+ if len(preds) != len(labels):
94
+ raise ValueError(f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}")
95
+ if task_name == "xnli":
96
+ return {"acc": simple_accuracy(preds, labels)}
97
+ else:
98
+ raise KeyError(task_name)
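A toy invocation of the deprecated helper above (requires scikit-learn and scipy; the predictions and labels are made up, and the upstream `transformers` package layout is assumed):

import numpy as np
from transformers.data.metrics import glue_compute_metrics

preds = np.array([1, 0, 1, 1])
labels = np.array([1, 0, 0, 1])
# MRPC reports accuracy and F1: {'acc': 0.75, 'f1': 0.8, 'acc_and_f1': 0.775}
print(glue_compute_metrics("mrpc", preds, labels))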
transformers_4_35_0/data/metrics/squad_metrics.py ADDED
@@ -0,0 +1,780 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Very heavily inspired by the official evaluation script for SQuAD version 2.0, which was modified by the XLNet authors to
16
+ update `find_best_threshold` scripts for SQuAD V2.0
17
+
18
+ In addition to basic functionality, we also compute additional statistics and plot precision-recall curves if an
19
+ additional na_prob.json file is provided. This file is expected to map question IDs to the model's predicted
20
+ probability that a question is unanswerable.
21
+ """
22
+
23
+
24
+ import collections
25
+ import json
26
+ import math
27
+ import re
28
+ import string
29
+
30
+ from ...models.bert import BasicTokenizer
31
+ from ...utils import logging
32
+
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+
37
+ def normalize_answer(s):
38
+ """Lower text and remove punctuation, articles and extra whitespace."""
39
+
40
+ def remove_articles(text):
41
+ regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
42
+ return re.sub(regex, " ", text)
43
+
44
+ def white_space_fix(text):
45
+ return " ".join(text.split())
46
+
47
+ def remove_punc(text):
48
+ exclude = set(string.punctuation)
49
+ return "".join(ch for ch in text if ch not in exclude)
50
+
51
+ def lower(text):
52
+ return text.lower()
53
+
54
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
55
+
56
+
57
+ def get_tokens(s):
58
+ if not s:
59
+ return []
60
+ return normalize_answer(s).split()
61
+
62
+
63
+ def compute_exact(a_gold, a_pred):
64
+ return int(normalize_answer(a_gold) == normalize_answer(a_pred))
65
+
66
+
67
+ def compute_f1(a_gold, a_pred):
68
+ gold_toks = get_tokens(a_gold)
69
+ pred_toks = get_tokens(a_pred)
70
+ common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
71
+ num_same = sum(common.values())
72
+ if len(gold_toks) == 0 or len(pred_toks) == 0:
73
+ # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
74
+ return int(gold_toks == pred_toks)
75
+ if num_same == 0:
76
+ return 0
77
+ precision = 1.0 * num_same / len(pred_toks)
78
+ recall = 1.0 * num_same / len(gold_toks)
79
+ f1 = (2 * precision * recall) / (precision + recall)
80
+ return f1
81
+
82
+
83
+ def get_raw_scores(examples, preds):
84
+ """
85
+ Computes the exact and f1 scores from the examples and the model predictions
86
+ """
87
+ exact_scores = {}
88
+ f1_scores = {}
89
+
90
+ for example in examples:
91
+ qas_id = example.qas_id
92
+ gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]
93
+
94
+ if not gold_answers:
95
+ # For unanswerable questions, only correct answer is empty string
96
+ gold_answers = [""]
97
+
98
+ if qas_id not in preds:
99
+ print(f"Missing prediction for {qas_id}")
100
+ continue
101
+
102
+ prediction = preds[qas_id]
103
+ exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers)
104
+ f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers)
105
+
106
+ return exact_scores, f1_scores
107
+
108
+
109
+ def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
110
+ new_scores = {}
111
+ for qid, s in scores.items():
112
+ pred_na = na_probs[qid] > na_prob_thresh
113
+ if pred_na:
114
+ new_scores[qid] = float(not qid_to_has_ans[qid])
115
+ else:
116
+ new_scores[qid] = s
117
+ return new_scores
118
+
119
+
120
+ def make_eval_dict(exact_scores, f1_scores, qid_list=None):
121
+ if not qid_list:
122
+ total = len(exact_scores)
123
+ return collections.OrderedDict(
124
+ [
125
+ ("exact", 100.0 * sum(exact_scores.values()) / total),
126
+ ("f1", 100.0 * sum(f1_scores.values()) / total),
127
+ ("total", total),
128
+ ]
129
+ )
130
+ else:
131
+ total = len(qid_list)
132
+ return collections.OrderedDict(
133
+ [
134
+ ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
135
+ ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
136
+ ("total", total),
137
+ ]
138
+ )
139
+
140
+
141
+ def merge_eval(main_eval, new_eval, prefix):
142
+ for k in new_eval:
143
+ main_eval[f"{prefix}_{k}"] = new_eval[k]
144
+
145
+
146
+ def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
147
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
148
+ cur_score = num_no_ans
149
+ best_score = cur_score
150
+ best_thresh = 0.0
151
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
152
+ for i, qid in enumerate(qid_list):
153
+ if qid not in scores:
154
+ continue
155
+ if qid_to_has_ans[qid]:
156
+ diff = scores[qid]
157
+ else:
158
+ if preds[qid]:
159
+ diff = -1
160
+ else:
161
+ diff = 0
162
+ cur_score += diff
163
+ if cur_score > best_score:
164
+ best_score = cur_score
165
+ best_thresh = na_probs[qid]
166
+
167
+ has_ans_score, has_ans_cnt = 0, 0
168
+ for qid in qid_list:
169
+ if not qid_to_has_ans[qid]:
170
+ continue
171
+ has_ans_cnt += 1
172
+
173
+ if qid not in scores:
174
+ continue
175
+ has_ans_score += scores[qid]
176
+
177
+ return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
178
+
179
+
180
+ def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
181
+ best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
182
+ best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
183
+ main_eval["best_exact"] = best_exact
184
+ main_eval["best_exact_thresh"] = exact_thresh
185
+ main_eval["best_f1"] = best_f1
186
+ main_eval["best_f1_thresh"] = f1_thresh
187
+ main_eval["has_ans_exact"] = has_ans_exact
188
+ main_eval["has_ans_f1"] = has_ans_f1
189
+
190
+
191
+ def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
192
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
193
+ cur_score = num_no_ans
194
+ best_score = cur_score
195
+ best_thresh = 0.0
196
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
197
+ for _, qid in enumerate(qid_list):
198
+ if qid not in scores:
199
+ continue
200
+ if qid_to_has_ans[qid]:
201
+ diff = scores[qid]
202
+ else:
203
+ if preds[qid]:
204
+ diff = -1
205
+ else:
206
+ diff = 0
207
+ cur_score += diff
208
+ if cur_score > best_score:
209
+ best_score = cur_score
210
+ best_thresh = na_probs[qid]
211
+ return 100.0 * best_score / len(scores), best_thresh
212
+
213
+
214
+ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
215
+ best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
216
+ best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
217
+
218
+ main_eval["best_exact"] = best_exact
219
+ main_eval["best_exact_thresh"] = exact_thresh
220
+ main_eval["best_f1"] = best_f1
221
+ main_eval["best_f1_thresh"] = f1_thresh
222
+
223
+
224
+ def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
225
+ qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}
226
+ has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer]
227
+ no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer]
228
+
229
+ if no_answer_probs is None:
230
+ no_answer_probs = {k: 0.0 for k in preds}
231
+
232
+ exact, f1 = get_raw_scores(examples, preds)
233
+
234
+ exact_threshold = apply_no_ans_threshold(
235
+ exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
236
+ )
237
+ f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
238
+
239
+ evaluation = make_eval_dict(exact_threshold, f1_threshold)
240
+
241
+ if has_answer_qids:
242
+ has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
243
+ merge_eval(evaluation, has_ans_eval, "HasAns")
244
+
245
+ if no_answer_qids:
246
+ no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
247
+ merge_eval(evaluation, no_ans_eval, "NoAns")
248
+
249
+ if no_answer_probs:
250
+ find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
251
+
252
+ return evaluation
253
+
254
+
255
+ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
256
+ """Project the tokenized prediction back to the original text."""
257
+
258
+ # When we created the data, we kept track of the alignment between original
259
+ # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
260
+ # now `orig_text` contains the span of our original text corresponding to the
261
+ # span that we predicted.
262
+ #
263
+ # However, `orig_text` may contain extra characters that we don't want in
264
+ # our prediction.
265
+ #
266
+ # For example, let's say:
267
+ # pred_text = steve smith
268
+ # orig_text = Steve Smith's
269
+ #
270
+ # We don't want to return `orig_text` because it contains the extra "'s".
271
+ #
272
+ # We don't want to return `pred_text` because it's already been normalized
273
+ # (the SQuAD eval script also does punctuation stripping/lower casing but
274
+ # our tokenizer does additional normalization like stripping accent
275
+ # characters).
276
+ #
277
+ # What we really want to return is "Steve Smith".
278
+ #
279
+ # Therefore, we have to apply a semi-complicated alignment heuristic between
280
+ # `pred_text` and `orig_text` to get a character-to-character alignment. This
281
+ # can fail in certain cases in which case we just return `orig_text`.
282
+
283
+ def _strip_spaces(text):
284
+ ns_chars = []
285
+ ns_to_s_map = collections.OrderedDict()
286
+ for i, c in enumerate(text):
287
+ if c == " ":
288
+ continue
289
+ ns_to_s_map[len(ns_chars)] = i
290
+ ns_chars.append(c)
291
+ ns_text = "".join(ns_chars)
292
+ return (ns_text, ns_to_s_map)
293
+
294
+ # We first tokenize `orig_text`, strip whitespace from the result
295
+ # and `pred_text`, and check if they are the same length. If they are
296
+ # NOT the same length, the heuristic has failed. If they are the same
297
+ # length, we assume the characters are one-to-one aligned.
298
+ tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
299
+
300
+ tok_text = " ".join(tokenizer.tokenize(orig_text))
301
+
302
+ start_position = tok_text.find(pred_text)
303
+ if start_position == -1:
304
+ if verbose_logging:
305
+ logger.info(f"Unable to find text: '{pred_text}' in '{orig_text}'")
306
+ return orig_text
307
+ end_position = start_position + len(pred_text) - 1
308
+
309
+ (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
310
+ (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
311
+
312
+ if len(orig_ns_text) != len(tok_ns_text):
313
+ if verbose_logging:
314
+ logger.info(f"Length not equal after stripping spaces: '{orig_ns_text}' vs '{tok_ns_text}'")
315
+ return orig_text
316
+
317
+ # We then project the characters in `pred_text` back to `orig_text` using
318
+ # the character-to-character alignment.
319
+ tok_s_to_ns_map = {}
320
+ for i, tok_index in tok_ns_to_s_map.items():
321
+ tok_s_to_ns_map[tok_index] = i
322
+
323
+ orig_start_position = None
324
+ if start_position in tok_s_to_ns_map:
325
+ ns_start_position = tok_s_to_ns_map[start_position]
326
+ if ns_start_position in orig_ns_to_s_map:
327
+ orig_start_position = orig_ns_to_s_map[ns_start_position]
328
+
329
+ if orig_start_position is None:
330
+ if verbose_logging:
331
+ logger.info("Couldn't map start position")
332
+ return orig_text
333
+
334
+ orig_end_position = None
335
+ if end_position in tok_s_to_ns_map:
336
+ ns_end_position = tok_s_to_ns_map[end_position]
337
+ if ns_end_position in orig_ns_to_s_map:
338
+ orig_end_position = orig_ns_to_s_map[ns_end_position]
339
+
340
+ if orig_end_position is None:
341
+ if verbose_logging:
342
+ logger.info("Couldn't map end position")
343
+ return orig_text
344
+
345
+ output_text = orig_text[orig_start_position : (orig_end_position + 1)]
346
+ return output_text
347
+
348
+
349
+ def _get_best_indexes(logits, n_best_size):
350
+ """Get the n-best logits from a list."""
351
+ index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
352
+
353
+ best_indexes = []
354
+ for i in range(len(index_and_score)):
355
+ if i >= n_best_size:
356
+ break
357
+ best_indexes.append(index_and_score[i][0])
358
+ return best_indexes
359
+
360
+
361
+ def _compute_softmax(scores):
362
+ """Compute softmax probability over raw logits."""
363
+ if not scores:
364
+ return []
365
+
366
+ max_score = None
367
+ for score in scores:
368
+ if max_score is None or score > max_score:
369
+ max_score = score
370
+
371
+ exp_scores = []
372
+ total_sum = 0.0
373
+ for score in scores:
374
+ x = math.exp(score - max_score)
375
+ exp_scores.append(x)
376
+ total_sum += x
377
+
378
+ probs = []
379
+ for score in exp_scores:
380
+ probs.append(score / total_sum)
381
+ return probs
382
+
383
+
384
+ def compute_predictions_logits(
385
+ all_examples,
386
+ all_features,
387
+ all_results,
388
+ n_best_size,
389
+ max_answer_length,
390
+ do_lower_case,
391
+ output_prediction_file,
392
+ output_nbest_file,
393
+ output_null_log_odds_file,
394
+ verbose_logging,
395
+ version_2_with_negative,
396
+ null_score_diff_threshold,
397
+ tokenizer,
398
+ ):
399
+ """Write final predictions to the json file and log-odds of null if needed."""
400
+ if output_prediction_file:
401
+ logger.info(f"Writing predictions to: {output_prediction_file}")
402
+ if output_nbest_file:
403
+ logger.info(f"Writing nbest to: {output_nbest_file}")
404
+ if output_null_log_odds_file and version_2_with_negative:
405
+ logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}")
406
+
407
+ example_index_to_features = collections.defaultdict(list)
408
+ for feature in all_features:
409
+ example_index_to_features[feature.example_index].append(feature)
410
+
411
+ unique_id_to_result = {}
412
+ for result in all_results:
413
+ unique_id_to_result[result.unique_id] = result
414
+
415
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
416
+ "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
417
+ )
418
+
419
+ all_predictions = collections.OrderedDict()
420
+ all_nbest_json = collections.OrderedDict()
421
+ scores_diff_json = collections.OrderedDict()
422
+
423
+ for example_index, example in enumerate(all_examples):
424
+ features = example_index_to_features[example_index]
425
+
426
+ prelim_predictions = []
427
+ # keep track of the minimum score of null start+end of position 0
428
+ score_null = 1000000 # large and positive
429
+ min_null_feature_index = 0 # the paragraph slice with min null score
430
+ null_start_logit = 0 # the start logit at the slice with min null score
431
+ null_end_logit = 0 # the end logit at the slice with min null score
432
+ for feature_index, feature in enumerate(features):
433
+ result = unique_id_to_result[feature.unique_id]
434
+ start_indexes = _get_best_indexes(result.start_logits, n_best_size)
435
+ end_indexes = _get_best_indexes(result.end_logits, n_best_size)
436
+ # if we could have irrelevant answers, get the min score of irrelevant
437
+ if version_2_with_negative:
438
+ feature_null_score = result.start_logits[0] + result.end_logits[0]
439
+ if feature_null_score < score_null:
440
+ score_null = feature_null_score
441
+ min_null_feature_index = feature_index
442
+ null_start_logit = result.start_logits[0]
443
+ null_end_logit = result.end_logits[0]
444
+ for start_index in start_indexes:
445
+ for end_index in end_indexes:
446
+ # We could hypothetically create invalid predictions, e.g., predict
447
+ # that the start of the span is in the question. We throw out all
448
+ # invalid predictions.
449
+ if start_index >= len(feature.tokens):
450
+ continue
451
+ if end_index >= len(feature.tokens):
452
+ continue
453
+ if start_index not in feature.token_to_orig_map:
454
+ continue
455
+ if end_index not in feature.token_to_orig_map:
456
+ continue
457
+ if not feature.token_is_max_context.get(start_index, False):
458
+ continue
459
+ if end_index < start_index:
460
+ continue
461
+ length = end_index - start_index + 1
462
+ if length > max_answer_length:
463
+ continue
464
+ prelim_predictions.append(
465
+ _PrelimPrediction(
466
+ feature_index=feature_index,
467
+ start_index=start_index,
468
+ end_index=end_index,
469
+ start_logit=result.start_logits[start_index],
470
+ end_logit=result.end_logits[end_index],
471
+ )
472
+ )
473
+ if version_2_with_negative:
474
+ prelim_predictions.append(
475
+ _PrelimPrediction(
476
+ feature_index=min_null_feature_index,
477
+ start_index=0,
478
+ end_index=0,
479
+ start_logit=null_start_logit,
480
+ end_logit=null_end_logit,
481
+ )
482
+ )
483
+ prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
484
+
485
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
486
+ "NbestPrediction", ["text", "start_logit", "end_logit"]
487
+ )
488
+
489
+ seen_predictions = {}
490
+ nbest = []
491
+ for pred in prelim_predictions:
492
+ if len(nbest) >= n_best_size:
493
+ break
494
+ feature = features[pred.feature_index]
495
+ if pred.start_index > 0: # this is a non-null prediction
496
+ tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
497
+ orig_doc_start = feature.token_to_orig_map[pred.start_index]
498
+ orig_doc_end = feature.token_to_orig_map[pred.end_index]
499
+ orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
500
+
501
+ tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
502
+
503
+ # tok_text = " ".join(tok_tokens)
504
+ #
505
+ # # De-tokenize WordPieces that have been split off.
506
+ # tok_text = tok_text.replace(" ##", "")
507
+ # tok_text = tok_text.replace("##", "")
508
+
509
+ # Clean whitespace
510
+ tok_text = tok_text.strip()
511
+ tok_text = " ".join(tok_text.split())
512
+ orig_text = " ".join(orig_tokens)
513
+
514
+ final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
515
+ if final_text in seen_predictions:
516
+ continue
517
+
518
+ seen_predictions[final_text] = True
519
+ else:
520
+ final_text = ""
521
+ seen_predictions[final_text] = True
522
+
523
+ nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
524
+ # if we didn't include the empty option in the n-best, include it
525
+ if version_2_with_negative:
526
+ if "" not in seen_predictions:
527
+ nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
528
+
529
+ # In very rare edge cases we could only have single null prediction.
530
+ # So we just create a nonce prediction in this case to avoid failure.
531
+ if len(nbest) == 1:
532
+ nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
533
+
534
+ # In very rare edge cases we could have no valid predictions. So we
535
+ # just create a nonce prediction in this case to avoid failure.
536
+ if not nbest:
537
+ nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
538
+
539
+ if len(nbest) < 1:
540
+ raise ValueError("No valid predictions")
541
+
542
+ total_scores = []
543
+ best_non_null_entry = None
544
+ for entry in nbest:
545
+ total_scores.append(entry.start_logit + entry.end_logit)
546
+ if not best_non_null_entry:
547
+ if entry.text:
548
+ best_non_null_entry = entry
549
+
550
+ probs = _compute_softmax(total_scores)
551
+
552
+ nbest_json = []
553
+ for i, entry in enumerate(nbest):
554
+ output = collections.OrderedDict()
555
+ output["text"] = entry.text
556
+ output["probability"] = probs[i]
557
+ output["start_logit"] = entry.start_logit
558
+ output["end_logit"] = entry.end_logit
559
+ nbest_json.append(output)
560
+
561
+ if len(nbest_json) < 1:
562
+ raise ValueError("No valid predictions")
563
+
564
+ if not version_2_with_negative:
565
+ all_predictions[example.qas_id] = nbest_json[0]["text"]
566
+ else:
567
+ # predict "" iff the null score - the score of best non-null > threshold
568
+ score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
569
+ scores_diff_json[example.qas_id] = score_diff
570
+ if score_diff > null_score_diff_threshold:
571
+ all_predictions[example.qas_id] = ""
572
+ else:
573
+ all_predictions[example.qas_id] = best_non_null_entry.text
574
+ all_nbest_json[example.qas_id] = nbest_json
575
+
576
+ if output_prediction_file:
577
+ with open(output_prediction_file, "w") as writer:
578
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
579
+
580
+ if output_nbest_file:
581
+ with open(output_nbest_file, "w") as writer:
582
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
583
+
584
+ if output_null_log_odds_file and version_2_with_negative:
585
+ with open(output_null_log_odds_file, "w") as writer:
586
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
587
+
588
+ return all_predictions
589
+
590
+
591
+ def compute_predictions_log_probs(
592
+ all_examples,
593
+ all_features,
594
+ all_results,
595
+ n_best_size,
596
+ max_answer_length,
597
+ output_prediction_file,
598
+ output_nbest_file,
599
+ output_null_log_odds_file,
600
+ start_n_top,
601
+ end_n_top,
602
+ version_2_with_negative,
603
+ tokenizer,
604
+ verbose_logging,
605
+ ):
606
+ """
607
+ XLNet write prediction logic (more complex than Bert's). Write final predictions to the json file and log-odds of
608
+ null if needed.
609
+
610
+ Requires utils_squad_evaluate.py
611
+ """
612
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
613
+ "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
614
+ )
615
+
616
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
617
+ "NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
618
+ )
619
+
620
+ logger.info(f"Writing predictions to: {output_prediction_file}")
621
+
622
+ example_index_to_features = collections.defaultdict(list)
623
+ for feature in all_features:
624
+ example_index_to_features[feature.example_index].append(feature)
625
+
626
+ unique_id_to_result = {}
627
+ for result in all_results:
628
+ unique_id_to_result[result.unique_id] = result
629
+
630
+ all_predictions = collections.OrderedDict()
631
+ all_nbest_json = collections.OrderedDict()
632
+ scores_diff_json = collections.OrderedDict()
633
+
634
+ for example_index, example in enumerate(all_examples):
635
+ features = example_index_to_features[example_index]
636
+
637
+ prelim_predictions = []
638
+ # keep track of the minimum score of null start+end of position 0
639
+ score_null = 1000000 # large and positive
640
+
641
+ for feature_index, feature in enumerate(features):
642
+ result = unique_id_to_result[feature.unique_id]
643
+
644
+ cur_null_score = result.cls_logits
645
+
646
+ # if we could have irrelevant answers, get the min score of irrelevant
647
+ score_null = min(score_null, cur_null_score)
648
+
649
+ for i in range(start_n_top):
650
+ for j in range(end_n_top):
651
+ start_log_prob = result.start_logits[i]
652
+ start_index = result.start_top_index[i]
653
+
654
+ j_index = i * end_n_top + j
655
+
656
+ end_log_prob = result.end_logits[j_index]
657
+ end_index = result.end_top_index[j_index]
658
+
659
+ # We could hypothetically create invalid predictions, e.g., predict
660
+ # that the start of the span is in the question. We throw out all
661
+ # invalid predictions.
662
+ if start_index >= feature.paragraph_len - 1:
663
+ continue
664
+ if end_index >= feature.paragraph_len - 1:
665
+ continue
666
+
667
+ if not feature.token_is_max_context.get(start_index, False):
668
+ continue
669
+ if end_index < start_index:
670
+ continue
671
+ length = end_index - start_index + 1
672
+ if length > max_answer_length:
673
+ continue
674
+
675
+ prelim_predictions.append(
676
+ _PrelimPrediction(
677
+ feature_index=feature_index,
678
+ start_index=start_index,
679
+ end_index=end_index,
680
+ start_log_prob=start_log_prob,
681
+ end_log_prob=end_log_prob,
682
+ )
683
+ )
684
+
685
+ prelim_predictions = sorted(
686
+ prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
687
+ )
688
+
689
+ seen_predictions = {}
690
+ nbest = []
691
+ for pred in prelim_predictions:
692
+ if len(nbest) >= n_best_size:
693
+ break
694
+ feature = features[pred.feature_index]
695
+
696
+ # XLNet un-tokenizer
697
+ # Let's keep it simple for now and see if we need all this later.
698
+ #
699
+ # tok_start_to_orig_index = feature.tok_start_to_orig_index
700
+ # tok_end_to_orig_index = feature.tok_end_to_orig_index
701
+ # start_orig_pos = tok_start_to_orig_index[pred.start_index]
702
+ # end_orig_pos = tok_end_to_orig_index[pred.end_index]
703
+ # paragraph_text = example.paragraph_text
704
+ # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
705
+
706
+ # Previously used Bert untokenizer
707
+ tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
708
+ orig_doc_start = feature.token_to_orig_map[pred.start_index]
709
+ orig_doc_end = feature.token_to_orig_map[pred.end_index]
710
+ orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
711
+ tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
712
+
713
+ # Clean whitespace
714
+ tok_text = tok_text.strip()
715
+ tok_text = " ".join(tok_text.split())
716
+ orig_text = " ".join(orig_tokens)
717
+
718
+ if hasattr(tokenizer, "do_lower_case"):
719
+ do_lower_case = tokenizer.do_lower_case
720
+ else:
721
+ do_lower_case = tokenizer.do_lowercase_and_remove_accent
722
+
723
+ final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
724
+
725
+ if final_text in seen_predictions:
726
+ continue
727
+
728
+ seen_predictions[final_text] = True
729
+
730
+ nbest.append(
731
+ _NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
732
+ )
733
+
734
+ # In very rare edge cases we could have no valid predictions. So we
735
+ # just create a nonce prediction in this case to avoid failure.
736
+ if not nbest:
737
+ nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
738
+
739
+ total_scores = []
740
+ best_non_null_entry = None
741
+ for entry in nbest:
742
+ total_scores.append(entry.start_log_prob + entry.end_log_prob)
743
+ if not best_non_null_entry:
744
+ best_non_null_entry = entry
745
+
746
+ probs = _compute_softmax(total_scores)
747
+
748
+ nbest_json = []
749
+ for i, entry in enumerate(nbest):
750
+ output = collections.OrderedDict()
751
+ output["text"] = entry.text
752
+ output["probability"] = probs[i]
753
+ output["start_log_prob"] = entry.start_log_prob
754
+ output["end_log_prob"] = entry.end_log_prob
755
+ nbest_json.append(output)
756
+
757
+ if len(nbest_json) < 1:
758
+ raise ValueError("No valid predictions")
759
+ if best_non_null_entry is None:
760
+ raise ValueError("No valid predictions")
761
+
762
+ score_diff = score_null
763
+ scores_diff_json[example.qas_id] = score_diff
764
+ # note(zhiliny): always predict best_non_null_entry
765
+ # and the evaluation script will search for the best threshold
766
+ all_predictions[example.qas_id] = best_non_null_entry.text
767
+
768
+ all_nbest_json[example.qas_id] = nbest_json
769
+
770
+ with open(output_prediction_file, "w") as writer:
771
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
772
+
773
+ with open(output_nbest_file, "w") as writer:
774
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
775
+
776
+ if version_2_with_negative:
777
+ with open(output_null_log_odds_file, "w") as writer:
778
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
779
+
780
+ return all_predictions
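For orientation only (not part of the added file): a minimal, hypothetical sketch of how the metric helpers above can be exercised. The import path assumes the standard `transformers` package layout; in this repo the code is vendored under `transformers_4_35_0`, so the module prefix would differ. The example objects are stand-ins built with `SimpleNamespace`, since `get_raw_scores` only reads the `qas_id` and `answers` attributes of a `SquadExample`.

# Usage sketch, assuming the standard `transformers` import path.
from types import SimpleNamespace

from transformers.data.metrics.squad_metrics import compute_exact, compute_f1, squad_evaluate

# Hypothetical examples: each needs a `qas_id` and a list of `answers` dicts.
examples = [
    SimpleNamespace(qas_id="q1", answers=[{"text": "Steve Smith"}]),
    SimpleNamespace(qas_id="q2", answers=[]),  # unanswerable question
]
preds = {"q1": "steve smith", "q2": ""}

print(compute_exact("Steve Smith", "steve smith"))  # 1 after normalization
print(compute_f1("Steve Smith", "steve"))           # 2/3, one shared token out of two gold tokens
print(squad_evaluate(examples, preds))              # OrderedDict with exact/f1/HasAns/NoAns entries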
transformers_4_35_0/data/processors/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels
16
+ from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
17
+ from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor
18
+ from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
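As a rough illustration (again, not part of the diff), the names re-exported here can be combined to turn a SQuAD v2 JSON file into model-ready features. This is a sketch under assumptions: the data directory and file name are hypothetical, and the call signatures are taken from the vendored files that follow in this commit.

# Usage sketch; "path/to/squad" and the checkpoint name are placeholders.
from transformers import AutoTokenizer
from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = SquadV2Processor()
examples = processor.get_dev_examples("path/to/squad", filename="dev-v2.0.json")

features, dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt",  # also returns a torch TensorDataset alongside the features
)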
transformers_4_35_0/data/processors/glue.py ADDED
@@ -0,0 +1,643 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ GLUE processors and helpers"""
17
+
18
+ import os
19
+ import warnings
20
+ from dataclasses import asdict
21
+ from enum import Enum
22
+ from typing import List, Optional, Union
23
+
24
+ from ...tokenization_utils import PreTrainedTokenizer
25
+ from ...utils import is_tf_available, logging
26
+ from .utils import DataProcessor, InputExample, InputFeatures
27
+
28
+
29
+ if is_tf_available():
30
+ import tensorflow as tf
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ DEPRECATION_WARNING = (
35
+ "This {0} will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
36
+ "library. You can have a look at this example script for pointers: "
37
+ "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
38
+ )
39
+
40
+
41
+ def glue_convert_examples_to_features(
42
+ examples: Union[List[InputExample], "tf.data.Dataset"],
43
+ tokenizer: PreTrainedTokenizer,
44
+ max_length: Optional[int] = None,
45
+ task=None,
46
+ label_list=None,
47
+ output_mode=None,
48
+ ):
49
+ """
50
+ Loads a data file into a list of `InputFeatures`
51
+
52
+ Args:
53
+ examples: List of `InputExamples` or `tf.data.Dataset` containing the examples.
54
+ tokenizer: Instance of a tokenizer that will tokenize the examples
55
+ max_length: Maximum example length. Defaults to the tokenizer's max_len
56
+ task: GLUE task
57
+ label_list: List of labels. Can be obtained from the processor using the `processor.get_labels()` method
58
+ output_mode: String indicating the output mode. Either `regression` or `classification`
59
+
60
+ Returns:
61
+ If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the task-specific
62
+ features. If the input is a list of `InputExamples`, will return a list of task-specific `InputFeatures` which
63
+ can be fed to the model.
64
+
65
+ """
66
+ warnings.warn(DEPRECATION_WARNING.format("function"), FutureWarning)
67
+ if is_tf_available() and isinstance(examples, tf.data.Dataset):
68
+ if task is None:
69
+ raise ValueError("When calling glue_convert_examples_to_features from TF, the task parameter is required.")
70
+ return _tf_glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
71
+ return _glue_convert_examples_to_features(
72
+ examples, tokenizer, max_length=max_length, task=task, label_list=label_list, output_mode=output_mode
73
+ )
74
+
75
+
76
+ if is_tf_available():
77
+
78
+ def _tf_glue_convert_examples_to_features(
79
+ examples: tf.data.Dataset,
80
+ tokenizer: PreTrainedTokenizer,
81
+ task=str,
82
+ max_length: Optional[int] = None,
83
+ ) -> tf.data.Dataset:
84
+ """
85
+ Returns:
86
+ A `tf.data.Dataset` containing the task-specific features.
87
+
88
+ """
89
+ processor = glue_processors[task]()
90
+ examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
91
+ features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
92
+ label_type = tf.float32 if task == "sts-b" else tf.int64
93
+
94
+ def gen():
95
+ for ex in features:
96
+ d = {k: v for k, v in asdict(ex).items() if v is not None}
97
+ label = d.pop("label")
98
+ yield (d, label)
99
+
100
+ input_names = tokenizer.model_input_names
101
+
102
+ return tf.data.Dataset.from_generator(
103
+ gen,
104
+ ({k: tf.int32 for k in input_names}, label_type),
105
+ ({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
106
+ )
107
+
108
+
109
+ def _glue_convert_examples_to_features(
110
+ examples: List[InputExample],
111
+ tokenizer: PreTrainedTokenizer,
112
+ max_length: Optional[int] = None,
113
+ task=None,
114
+ label_list=None,
115
+ output_mode=None,
116
+ ):
117
+ if max_length is None:
118
+ max_length = tokenizer.model_max_length
119
+
120
+ if task is not None:
121
+ processor = glue_processors[task]()
122
+ if label_list is None:
123
+ label_list = processor.get_labels()
124
+ logger.info(f"Using label list {label_list} for task {task}")
125
+ if output_mode is None:
126
+ output_mode = glue_output_modes[task]
127
+ logger.info(f"Using output mode {output_mode} for task {task}")
128
+
129
+ label_map = {label: i for i, label in enumerate(label_list)}
130
+
131
+ def label_from_example(example: InputExample) -> Union[int, float, None]:
132
+ if example.label is None:
133
+ return None
134
+ if output_mode == "classification":
135
+ return label_map[example.label]
136
+ elif output_mode == "regression":
137
+ return float(example.label)
138
+ raise KeyError(output_mode)
139
+
140
+ labels = [label_from_example(example) for example in examples]
141
+
142
+ batch_encoding = tokenizer(
143
+ [(example.text_a, example.text_b) for example in examples],
144
+ max_length=max_length,
145
+ padding="max_length",
146
+ truncation=True,
147
+ )
148
+
149
+ features = []
150
+ for i in range(len(examples)):
151
+ inputs = {k: batch_encoding[k][i] for k in batch_encoding}
152
+
153
+ feature = InputFeatures(**inputs, label=labels[i])
154
+ features.append(feature)
155
+
156
+ for i, example in enumerate(examples[:5]):
157
+ logger.info("*** Example ***")
158
+ logger.info(f"guid: {example.guid}")
159
+ logger.info(f"features: {features[i]}")
160
+
161
+ return features
162
+
163
+
164
+ class OutputMode(Enum):
165
+ classification = "classification"
166
+ regression = "regression"
167
+
168
+
169
+ class MrpcProcessor(DataProcessor):
170
+ """Processor for the MRPC data set (GLUE version)."""
171
+
172
+ def __init__(self, *args, **kwargs):
173
+ super().__init__(*args, **kwargs)
174
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
175
+
176
+ def get_example_from_tensor_dict(self, tensor_dict):
177
+ """See base class."""
178
+ return InputExample(
179
+ tensor_dict["idx"].numpy(),
180
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
181
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
182
+ str(tensor_dict["label"].numpy()),
183
+ )
184
+
185
+ def get_train_examples(self, data_dir):
186
+ """See base class."""
187
+ logger.info(f"LOOKING AT {os.path.join(data_dir, 'train.tsv')}")
188
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
189
+
190
+ def get_dev_examples(self, data_dir):
191
+ """See base class."""
192
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
193
+
194
+ def get_test_examples(self, data_dir):
195
+ """See base class."""
196
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
197
+
198
+ def get_labels(self):
199
+ """See base class."""
200
+ return ["0", "1"]
201
+
202
+ def _create_examples(self, lines, set_type):
203
+ """Creates examples for the training, dev and test sets."""
204
+ examples = []
205
+ for i, line in enumerate(lines):
206
+ if i == 0:
207
+ continue
208
+ guid = f"{set_type}-{i}"
209
+ text_a = line[3]
210
+ text_b = line[4]
211
+ label = None if set_type == "test" else line[0]
212
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
213
+ return examples
214
+
215
+
216
+ class MnliProcessor(DataProcessor):
217
+ """Processor for the MultiNLI data set (GLUE version)."""
218
+
219
+ def __init__(self, *args, **kwargs):
220
+ super().__init__(*args, **kwargs)
221
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
222
+
223
+ def get_example_from_tensor_dict(self, tensor_dict):
224
+ """See base class."""
225
+ return InputExample(
226
+ tensor_dict["idx"].numpy(),
227
+ tensor_dict["premise"].numpy().decode("utf-8"),
228
+ tensor_dict["hypothesis"].numpy().decode("utf-8"),
229
+ str(tensor_dict["label"].numpy()),
230
+ )
231
+
232
+ def get_train_examples(self, data_dir):
233
+ """See base class."""
234
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
235
+
236
+ def get_dev_examples(self, data_dir):
237
+ """See base class."""
238
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
239
+
240
+ def get_test_examples(self, data_dir):
241
+ """See base class."""
242
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
243
+
244
+ def get_labels(self):
245
+ """See base class."""
246
+ return ["contradiction", "entailment", "neutral"]
247
+
248
+ def _create_examples(self, lines, set_type):
249
+ """Creates examples for the training, dev and test sets."""
250
+ examples = []
251
+ for i, line in enumerate(lines):
252
+ if i == 0:
253
+ continue
254
+ guid = f"{set_type}-{line[0]}"
255
+ text_a = line[8]
256
+ text_b = line[9]
257
+ label = None if set_type.startswith("test") else line[-1]
258
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
259
+ return examples
260
+
261
+
262
+ class MnliMismatchedProcessor(MnliProcessor):
263
+ """Processor for the MultiNLI Mismatched data set (GLUE version)."""
264
+
265
+ def __init__(self, *args, **kwargs):
266
+ super().__init__(*args, **kwargs)
267
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
268
+
269
+ def get_dev_examples(self, data_dir):
270
+ """See base class."""
271
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
272
+
273
+ def get_test_examples(self, data_dir):
274
+ """See base class."""
275
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
276
+
277
+
278
+ class ColaProcessor(DataProcessor):
279
+ """Processor for the CoLA data set (GLUE version)."""
280
+
281
+ def __init__(self, *args, **kwargs):
282
+ super().__init__(*args, **kwargs)
283
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
284
+
285
+ def get_example_from_tensor_dict(self, tensor_dict):
286
+ """See base class."""
287
+ return InputExample(
288
+ tensor_dict["idx"].numpy(),
289
+ tensor_dict["sentence"].numpy().decode("utf-8"),
290
+ None,
291
+ str(tensor_dict["label"].numpy()),
292
+ )
293
+
294
+ def get_train_examples(self, data_dir):
295
+ """See base class."""
296
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
297
+
298
+ def get_dev_examples(self, data_dir):
299
+ """See base class."""
300
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
301
+
302
+ def get_test_examples(self, data_dir):
303
+ """See base class."""
304
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
305
+
306
+ def get_labels(self):
307
+ """See base class."""
308
+ return ["0", "1"]
309
+
310
+ def _create_examples(self, lines, set_type):
311
+ """Creates examples for the training, dev and test sets."""
312
+ test_mode = set_type == "test"
313
+ if test_mode:
314
+ lines = lines[1:]
315
+ text_index = 1 if test_mode else 3
316
+ examples = []
317
+ for i, line in enumerate(lines):
318
+ guid = f"{set_type}-{i}"
319
+ text_a = line[text_index]
320
+ label = None if test_mode else line[1]
321
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
322
+ return examples
323
+
324
+
325
+ class Sst2Processor(DataProcessor):
326
+ """Processor for the SST-2 data set (GLUE version)."""
327
+
328
+ def __init__(self, *args, **kwargs):
329
+ super().__init__(*args, **kwargs)
330
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
331
+
332
+ def get_example_from_tensor_dict(self, tensor_dict):
333
+ """See base class."""
334
+ return InputExample(
335
+ tensor_dict["idx"].numpy(),
336
+ tensor_dict["sentence"].numpy().decode("utf-8"),
337
+ None,
338
+ str(tensor_dict["label"].numpy()),
339
+ )
340
+
341
+ def get_train_examples(self, data_dir):
342
+ """See base class."""
343
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
344
+
345
+ def get_dev_examples(self, data_dir):
346
+ """See base class."""
347
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
348
+
349
+ def get_test_examples(self, data_dir):
350
+ """See base class."""
351
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
352
+
353
+ def get_labels(self):
354
+ """See base class."""
355
+ return ["0", "1"]
356
+
357
+ def _create_examples(self, lines, set_type):
358
+ """Creates examples for the training, dev and test sets."""
359
+ examples = []
360
+ text_index = 1 if set_type == "test" else 0
361
+ for i, line in enumerate(lines):
362
+ if i == 0:
363
+ continue
364
+ guid = f"{set_type}-{i}"
365
+ text_a = line[text_index]
366
+ label = None if set_type == "test" else line[1]
367
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
368
+ return examples
369
+
370
+
371
+ class StsbProcessor(DataProcessor):
372
+ """Processor for the STS-B data set (GLUE version)."""
373
+
374
+ def __init__(self, *args, **kwargs):
375
+ super().__init__(*args, **kwargs)
376
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
377
+
378
+ def get_example_from_tensor_dict(self, tensor_dict):
379
+ """See base class."""
380
+ return InputExample(
381
+ tensor_dict["idx"].numpy(),
382
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
383
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
384
+ str(tensor_dict["label"].numpy()),
385
+ )
386
+
387
+ def get_train_examples(self, data_dir):
388
+ """See base class."""
389
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
390
+
391
+ def get_dev_examples(self, data_dir):
392
+ """See base class."""
393
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
394
+
395
+ def get_test_examples(self, data_dir):
396
+ """See base class."""
397
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
398
+
399
+ def get_labels(self):
400
+ """See base class."""
401
+ return [None]
402
+
403
+ def _create_examples(self, lines, set_type):
404
+ """Creates examples for the training, dev and test sets."""
405
+ examples = []
406
+ for i, line in enumerate(lines):
407
+ if i == 0:
408
+ continue
409
+ guid = f"{set_type}-{line[0]}"
410
+ text_a = line[7]
411
+ text_b = line[8]
412
+ label = None if set_type == "test" else line[-1]
413
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
414
+ return examples
415
+
416
+
417
+ class QqpProcessor(DataProcessor):
418
+ """Processor for the QQP data set (GLUE version)."""
419
+
420
+ def __init__(self, *args, **kwargs):
421
+ super().__init__(*args, **kwargs)
422
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
423
+
424
+ def get_example_from_tensor_dict(self, tensor_dict):
425
+ """See base class."""
426
+ return InputExample(
427
+ tensor_dict["idx"].numpy(),
428
+ tensor_dict["question1"].numpy().decode("utf-8"),
429
+ tensor_dict["question2"].numpy().decode("utf-8"),
430
+ str(tensor_dict["label"].numpy()),
431
+ )
432
+
433
+ def get_train_examples(self, data_dir):
434
+ """See base class."""
435
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
436
+
437
+ def get_dev_examples(self, data_dir):
438
+ """See base class."""
439
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
440
+
441
+ def get_test_examples(self, data_dir):
442
+ """See base class."""
443
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
444
+
445
+ def get_labels(self):
446
+ """See base class."""
447
+ return ["0", "1"]
448
+
449
+ def _create_examples(self, lines, set_type):
450
+ """Creates examples for the training, dev and test sets."""
451
+ test_mode = set_type == "test"
452
+ q1_index = 1 if test_mode else 3
453
+ q2_index = 2 if test_mode else 4
454
+ examples = []
455
+ for i, line in enumerate(lines):
456
+ if i == 0:
457
+ continue
458
+ guid = f"{set_type}-{line[0]}"
459
+ try:
460
+ text_a = line[q1_index]
461
+ text_b = line[q2_index]
462
+ label = None if test_mode else line[5]
463
+ except IndexError:
464
+ continue
465
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
466
+ return examples
467
+
468
+
469
+ class QnliProcessor(DataProcessor):
470
+ """Processor for the QNLI data set (GLUE version)."""
471
+
472
+ def __init__(self, *args, **kwargs):
473
+ super().__init__(*args, **kwargs)
474
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
475
+
476
+ def get_example_from_tensor_dict(self, tensor_dict):
477
+ """See base class."""
478
+ return InputExample(
479
+ tensor_dict["idx"].numpy(),
480
+ tensor_dict["question"].numpy().decode("utf-8"),
481
+ tensor_dict["sentence"].numpy().decode("utf-8"),
482
+ str(tensor_dict["label"].numpy()),
483
+ )
484
+
485
+ def get_train_examples(self, data_dir):
486
+ """See base class."""
487
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
488
+
489
+ def get_dev_examples(self, data_dir):
490
+ """See base class."""
491
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
492
+
493
+ def get_test_examples(self, data_dir):
494
+ """See base class."""
495
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
496
+
497
+ def get_labels(self):
498
+ """See base class."""
499
+ return ["entailment", "not_entailment"]
500
+
501
+ def _create_examples(self, lines, set_type):
502
+ """Creates examples for the training, dev and test sets."""
503
+ examples = []
504
+ for i, line in enumerate(lines):
505
+ if i == 0:
506
+ continue
507
+ guid = f"{set_type}-{line[0]}"
508
+ text_a = line[1]
509
+ text_b = line[2]
510
+ label = None if set_type == "test" else line[-1]
511
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
512
+ return examples
513
+
514
+
515
+ class RteProcessor(DataProcessor):
516
+ """Processor for the RTE data set (GLUE version)."""
517
+
518
+ def __init__(self, *args, **kwargs):
519
+ super().__init__(*args, **kwargs)
520
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
521
+
522
+ def get_example_from_tensor_dict(self, tensor_dict):
523
+ """See base class."""
524
+ return InputExample(
525
+ tensor_dict["idx"].numpy(),
526
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
527
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
528
+ str(tensor_dict["label"].numpy()),
529
+ )
530
+
531
+ def get_train_examples(self, data_dir):
532
+ """See base class."""
533
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
534
+
535
+ def get_dev_examples(self, data_dir):
536
+ """See base class."""
537
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
538
+
539
+ def get_test_examples(self, data_dir):
540
+ """See base class."""
541
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
542
+
543
+ def get_labels(self):
544
+ """See base class."""
545
+ return ["entailment", "not_entailment"]
546
+
547
+ def _create_examples(self, lines, set_type):
548
+ """Creates examples for the training, dev and test sets."""
549
+ examples = []
550
+ for i, line in enumerate(lines):
551
+ if i == 0:
552
+ continue
553
+ guid = f"{set_type}-{line[0]}"
554
+ text_a = line[1]
555
+ text_b = line[2]
556
+ label = None if set_type == "test" else line[-1]
557
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
558
+ return examples
559
+
560
+
561
+ class WnliProcessor(DataProcessor):
562
+ """Processor for the WNLI data set (GLUE version)."""
563
+
564
+ def __init__(self, *args, **kwargs):
565
+ super().__init__(*args, **kwargs)
566
+ warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
567
+
568
+ def get_example_from_tensor_dict(self, tensor_dict):
569
+ """See base class."""
570
+ return InputExample(
571
+ tensor_dict["idx"].numpy(),
572
+ tensor_dict["sentence1"].numpy().decode("utf-8"),
573
+ tensor_dict["sentence2"].numpy().decode("utf-8"),
574
+ str(tensor_dict["label"].numpy()),
575
+ )
576
+
577
+ def get_train_examples(self, data_dir):
578
+ """See base class."""
579
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
580
+
581
+ def get_dev_examples(self, data_dir):
582
+ """See base class."""
583
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
584
+
585
+ def get_test_examples(self, data_dir):
586
+ """See base class."""
587
+ return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
588
+
589
+ def get_labels(self):
590
+ """See base class."""
591
+ return ["0", "1"]
592
+
593
+ def _create_examples(self, lines, set_type):
594
+ """Creates examples for the training, dev and test sets."""
595
+ examples = []
596
+ for i, line in enumerate(lines):
597
+ if i == 0:
598
+ continue
599
+ guid = f"{set_type}-{line[0]}"
600
+ text_a = line[1]
601
+ text_b = line[2]
602
+ label = None if set_type == "test" else line[-1]
603
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
604
+ return examples
605
+
606
+
607
+ glue_tasks_num_labels = {
608
+ "cola": 2,
609
+ "mnli": 3,
610
+ "mrpc": 2,
611
+ "sst-2": 2,
612
+ "sts-b": 1,
613
+ "qqp": 2,
614
+ "qnli": 2,
615
+ "rte": 2,
616
+ "wnli": 2,
617
+ }
618
+
619
+ glue_processors = {
620
+ "cola": ColaProcessor,
621
+ "mnli": MnliProcessor,
622
+ "mnli-mm": MnliMismatchedProcessor,
623
+ "mrpc": MrpcProcessor,
624
+ "sst-2": Sst2Processor,
625
+ "sts-b": StsbProcessor,
626
+ "qqp": QqpProcessor,
627
+ "qnli": QnliProcessor,
628
+ "rte": RteProcessor,
629
+ "wnli": WnliProcessor,
630
+ }
631
+
632
+ glue_output_modes = {
633
+ "cola": "classification",
634
+ "mnli": "classification",
635
+ "mnli-mm": "classification",
636
+ "mrpc": "classification",
637
+ "sst-2": "classification",
638
+ "sts-b": "regression",
639
+ "qqp": "classification",
640
+ "qnli": "classification",
641
+ "rte": "classification",
642
+ "wnli": "classification",
643
+ }
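A brief usage sketch for the GLUE helpers above (not part of the added file; the data directory is hypothetical). Passing `task="mrpc"` lets `glue_convert_examples_to_features` pick the label list and the "classification" output mode from the tables defined at the end of the module; note that the processors emit a `FutureWarning` because this preprocessing path is deprecated in favor of 🤗 Datasets.

# Usage sketch; "path/to/MRPC" is a placeholder directory containing dev.tsv.
from transformers import AutoTokenizer
from transformers.data.processors.glue import MrpcProcessor, glue_convert_examples_to_features

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
processor = MrpcProcessor()  # warns: processors will be removed from the library
examples = processor.get_dev_examples("path/to/MRPC")

features = glue_convert_examples_to_features(
    examples,
    tokenizer,
    max_length=128,
    task="mrpc",  # resolves label_list and output_mode automatically
)
print(features[0].input_ids[:10], features[0].label)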
transformers_4_35_0/data/processors/squad.py ADDED
@@ -0,0 +1,845 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ from functools import partial
18
+ from multiprocessing import Pool, cpu_count
19
+
20
+ import numpy as np
21
+ from tqdm import tqdm
22
+
23
+ from ...models.bert.tokenization_bert import whitespace_tokenize
24
+ from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy
25
+ from ...utils import is_tf_available, is_torch_available, logging
26
+ from .utils import DataProcessor
27
+
28
+
29
+ # Store the tokenizers which insert 2 separators tokens
30
+ MULTI_SEP_TOKENS_TOKENIZERS_SET = {"roberta", "camembert", "bart", "mpnet"}
31
+
32
+
33
+ if is_torch_available():
34
+ import torch
35
+ from torch.utils.data import TensorDataset
36
+
37
+ if is_tf_available():
38
+ import tensorflow as tf
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+
43
+ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
44
+ """Returns tokenized answer spans that better match the annotated answer."""
45
+ tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
46
+
47
+ for new_start in range(input_start, input_end + 1):
48
+ for new_end in range(input_end, new_start - 1, -1):
49
+ text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
50
+ if text_span == tok_answer_text:
51
+ return (new_start, new_end)
52
+
53
+ return (input_start, input_end)
54
+
55
+
56
+ def _check_is_max_context(doc_spans, cur_span_index, position):
57
+ """Check if this is the 'max context' doc span for the token."""
58
+ best_score = None
59
+ best_span_index = None
60
+ for span_index, doc_span in enumerate(doc_spans):
61
+ end = doc_span.start + doc_span.length - 1
62
+ if position < doc_span.start:
63
+ continue
64
+ if position > end:
65
+ continue
66
+ num_left_context = position - doc_span.start
67
+ num_right_context = end - position
68
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
69
+ if best_score is None or score > best_score:
70
+ best_score = score
71
+ best_span_index = span_index
72
+
73
+ return cur_span_index == best_span_index
74
+
75
+
76
+ def _new_check_is_max_context(doc_spans, cur_span_index, position):
77
+ """Check if this is the 'max context' doc span for the token."""
78
+ # if len(doc_spans) == 1:
79
+ # return True
80
+ best_score = None
81
+ best_span_index = None
82
+ for span_index, doc_span in enumerate(doc_spans):
83
+ end = doc_span["start"] + doc_span["length"] - 1
84
+ if position < doc_span["start"]:
85
+ continue
86
+ if position > end:
87
+ continue
88
+ num_left_context = position - doc_span["start"]
89
+ num_right_context = end - position
90
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
91
+ if best_score is None or score > best_score:
92
+ best_score = score
93
+ best_span_index = span_index
94
+
95
+ return cur_span_index == best_span_index
96
+
97
+
98
+ def _is_whitespace(c):
99
+ if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
100
+ return True
101
+ return False
102
+
103
+
104
+ def squad_convert_example_to_features(
105
+ example, max_seq_length, doc_stride, max_query_length, padding_strategy, is_training
106
+ ):
107
+ features = []
108
+ if is_training and not example.is_impossible:
109
+ # Get start and end position
110
+ start_position = example.start_position
111
+ end_position = example.end_position
112
+
113
+ # If the answer cannot be found in the text, then skip this example.
114
+ actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
115
+ cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
116
+ if actual_text.find(cleaned_answer_text) == -1:
117
+ logger.warning(f"Could not find answer: '{actual_text}' vs. '{cleaned_answer_text}'")
118
+ return []
119
+
120
+ tok_to_orig_index = []
121
+ orig_to_tok_index = []
122
+ all_doc_tokens = []
123
+ for i, token in enumerate(example.doc_tokens):
124
+ orig_to_tok_index.append(len(all_doc_tokens))
125
+ if tokenizer.__class__.__name__ in [
126
+ "RobertaTokenizer",
127
+ "LongformerTokenizer",
128
+ "BartTokenizer",
129
+ "RobertaTokenizerFast",
130
+ "LongformerTokenizerFast",
131
+ "BartTokenizerFast",
132
+ ]:
133
+ sub_tokens = tokenizer.tokenize(token, add_prefix_space=True)
134
+ else:
135
+ sub_tokens = tokenizer.tokenize(token)
136
+ for sub_token in sub_tokens:
137
+ tok_to_orig_index.append(i)
138
+ all_doc_tokens.append(sub_token)
139
+
140
+ if is_training and not example.is_impossible:
141
+ tok_start_position = orig_to_tok_index[example.start_position]
142
+ if example.end_position < len(example.doc_tokens) - 1:
143
+ tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
144
+ else:
145
+ tok_end_position = len(all_doc_tokens) - 1
146
+
147
+ (tok_start_position, tok_end_position) = _improve_answer_span(
148
+ all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
149
+ )
150
+
151
+ spans = []
152
+
153
+ truncated_query = tokenizer.encode(
154
+ example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length
155
+ )
156
+
157
+ # Tokenizers that insert 2 SEP tokens in-between <context> & <question> need to have special handling
158
+ # in the way they compute the mask of added tokens.
159
+ tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
160
+ sequence_added_tokens = (
161
+ tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1
162
+ if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET
163
+ else tokenizer.model_max_length - tokenizer.max_len_single_sentence
164
+ )
165
+ sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair
166
+
167
+ span_doc_tokens = all_doc_tokens
168
+ while len(spans) * doc_stride < len(all_doc_tokens):
169
+ # Define the side we want to truncate / pad and the text/pair sorting
170
+ if tokenizer.padding_side == "right":
171
+ texts = truncated_query
172
+ pairs = span_doc_tokens
173
+ truncation = TruncationStrategy.ONLY_SECOND.value
174
+ else:
175
+ texts = span_doc_tokens
176
+ pairs = truncated_query
177
+ truncation = TruncationStrategy.ONLY_FIRST.value
178
+
179
+ encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic
180
+ texts,
181
+ pairs,
182
+ truncation=truncation,
183
+ padding=padding_strategy,
184
+ max_length=max_seq_length,
185
+ return_overflowing_tokens=True,
186
+ stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
187
+ return_token_type_ids=True,
188
+ )
189
+
190
+ paragraph_len = min(
191
+ len(all_doc_tokens) - len(spans) * doc_stride,
192
+ max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
193
+ )
194
+
195
+ if tokenizer.pad_token_id in encoded_dict["input_ids"]:
196
+ if tokenizer.padding_side == "right":
197
+ non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
198
+ else:
199
+ last_padding_id_position = (
200
+ len(encoded_dict["input_ids"]) - 1 - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id)
201
+ )
202
+ non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :]
203
+
204
+ else:
205
+ non_padded_ids = encoded_dict["input_ids"]
206
+
207
+ tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
208
+
209
+ token_to_orig_map = {}
210
+ for i in range(paragraph_len):
211
+ index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
212
+ token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
213
+
214
+ encoded_dict["paragraph_len"] = paragraph_len
215
+ encoded_dict["tokens"] = tokens
216
+ encoded_dict["token_to_orig_map"] = token_to_orig_map
217
+ encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
218
+ encoded_dict["token_is_max_context"] = {}
219
+ encoded_dict["start"] = len(spans) * doc_stride
220
+ encoded_dict["length"] = paragraph_len
221
+
222
+ spans.append(encoded_dict)
223
+
224
+ if "overflowing_tokens" not in encoded_dict or (
225
+ "overflowing_tokens" in encoded_dict and len(encoded_dict["overflowing_tokens"]) == 0
226
+ ):
227
+ break
228
+ span_doc_tokens = encoded_dict["overflowing_tokens"]
229
+
230
+ for doc_span_index in range(len(spans)):
231
+ for j in range(spans[doc_span_index]["paragraph_len"]):
232
+ is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
233
+ index = (
234
+ j
235
+ if tokenizer.padding_side == "left"
236
+ else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
237
+ )
238
+ spans[doc_span_index]["token_is_max_context"][index] = is_max_context
239
+
240
+ for span in spans:
241
+ # Identify the position of the CLS token
242
+ cls_index = span["input_ids"].index(tokenizer.cls_token_id)
243
+
244
+ # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens that can be in an answer)
245
+ # The original TF implementation also keeps the classification token (set to 0)
246
+ p_mask = np.ones_like(span["token_type_ids"])
247
+ if tokenizer.padding_side == "right":
248
+ p_mask[len(truncated_query) + sequence_added_tokens :] = 0
249
+ else:
250
+ p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
251
+
252
+ pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
253
+ special_token_indices = np.asarray(
254
+ tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
255
+ ).nonzero()
256
+
257
+ p_mask[pad_token_indices] = 1
258
+ p_mask[special_token_indices] = 1
259
+
260
+ # Set the cls index to 0: the CLS index can be used for impossible answers
261
+ p_mask[cls_index] = 0
262
+
263
+ span_is_impossible = example.is_impossible
264
+ start_position = 0
265
+ end_position = 0
266
+ if is_training and not span_is_impossible:
267
+ # For training, if our document chunk does not contain an annotation
268
+ # we throw it out, since there is nothing to predict.
269
+ doc_start = span["start"]
270
+ doc_end = span["start"] + span["length"] - 1
271
+ out_of_span = False
272
+
273
+ if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
274
+ out_of_span = True
275
+
276
+ if out_of_span:
277
+ start_position = cls_index
278
+ end_position = cls_index
279
+ span_is_impossible = True
280
+ else:
281
+ if tokenizer.padding_side == "left":
282
+ doc_offset = 0
283
+ else:
284
+ doc_offset = len(truncated_query) + sequence_added_tokens
285
+
286
+ start_position = tok_start_position - doc_start + doc_offset
287
+ end_position = tok_end_position - doc_start + doc_offset
288
+
289
+ features.append(
290
+ SquadFeatures(
291
+ span["input_ids"],
292
+ span["attention_mask"],
293
+ span["token_type_ids"],
294
+ cls_index,
295
+ p_mask.tolist(),
296
+ example_index=0, # Cannot set unique_id and example_index here; they are set after the multiprocessing step.
297
+ unique_id=0,
298
+ paragraph_len=span["paragraph_len"],
299
+ token_is_max_context=span["token_is_max_context"],
300
+ tokens=span["tokens"],
301
+ token_to_orig_map=span["token_to_orig_map"],
302
+ start_position=start_position,
303
+ end_position=end_position,
304
+ is_impossible=span_is_impossible,
305
+ qas_id=example.qas_id,
306
+ )
307
+ )
308
+ return features
309
+
310
+
311
+ def squad_convert_example_to_features_init(tokenizer_for_convert: PreTrainedTokenizerBase):
312
+ global tokenizer
313
+ tokenizer = tokenizer_for_convert
314
+
315
+
316
+ def squad_convert_examples_to_features(
317
+ examples,
318
+ tokenizer,
319
+ max_seq_length,
320
+ doc_stride,
321
+ max_query_length,
322
+ is_training,
323
+ padding_strategy="max_length",
324
+ return_dataset=False,
325
+ threads=1,
326
+ tqdm_enabled=True,
327
+ ):
328
+ """
329
+ Converts a list of examples into a list of features that can be directly given as input to a model. It is
330
+ model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs.
331
+
332
+ Args:
333
+ examples: list of [`~data.processors.squad.SquadExample`]
334
+ tokenizer: an instance of a child of [`PreTrainedTokenizer`]
335
+ max_seq_length: The maximum sequence length of the inputs.
336
+ doc_stride: The stride used when the context is too large and is split across several features.
337
+ max_query_length: The maximum length of the query.
338
+ is_training: whether to create features for model evaluation or model training.
339
+ padding_strategy: Defaults to "max_length". The padding strategy to use.
340
+ return_dataset: Default False. Either 'pt' or 'tf'.
341
+ if 'pt': returns a torch.utils.data.TensorDataset, if 'tf': returns a tf.data.Dataset
342
+ threads: number of parallel processes used to convert the examples.
343
+
344
+
345
+ Returns:
346
+ list of [`~data.processors.squad.SquadFeatures`]
347
+
348
+ Example:
349
+
350
+ ```python
351
+ processor = SquadV2Processor()
352
+ examples = processor.get_dev_examples(data_dir)
353
+
354
+ features = squad_convert_examples_to_features(
355
+ examples=examples,
356
+ tokenizer=tokenizer,
357
+ max_seq_length=args.max_seq_length,
358
+ doc_stride=args.doc_stride,
359
+ max_query_length=args.max_query_length,
360
+ is_training=not evaluate,
361
+ )
362
+ ```"""
363
+ # Defining helper methods
364
+ features = []
365
+
366
+ threads = min(threads, cpu_count())
367
+ with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
368
+ annotate_ = partial(
369
+ squad_convert_example_to_features,
370
+ max_seq_length=max_seq_length,
371
+ doc_stride=doc_stride,
372
+ max_query_length=max_query_length,
373
+ padding_strategy=padding_strategy,
374
+ is_training=is_training,
375
+ )
376
+ features = list(
377
+ tqdm(
378
+ p.imap(annotate_, examples, chunksize=32),
379
+ total=len(examples),
380
+ desc="convert squad examples to features",
381
+ disable=not tqdm_enabled,
382
+ )
383
+ )
384
+
385
+ new_features = []
386
+ unique_id = 1000000000
387
+ example_index = 0
388
+ for example_features in tqdm(
389
+ features, total=len(features), desc="add example index and unique id", disable=not tqdm_enabled
390
+ ):
391
+ if not example_features:
392
+ continue
393
+ for example_feature in example_features:
394
+ example_feature.example_index = example_index
395
+ example_feature.unique_id = unique_id
396
+ new_features.append(example_feature)
397
+ unique_id += 1
398
+ example_index += 1
399
+ features = new_features
400
+ del new_features
401
+ if return_dataset == "pt":
402
+ if not is_torch_available():
403
+ raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")
404
+
405
+ # Convert to Tensors and build dataset
406
+ all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
407
+ all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
408
+ all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
409
+ all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
410
+ all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
411
+ all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
412
+
413
+ if not is_training:
414
+ all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
415
+ dataset = TensorDataset(
416
+ all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask
417
+ )
418
+ else:
419
+ all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
420
+ all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
421
+ dataset = TensorDataset(
422
+ all_input_ids,
423
+ all_attention_masks,
424
+ all_token_type_ids,
425
+ all_start_positions,
426
+ all_end_positions,
427
+ all_cls_index,
428
+ all_p_mask,
429
+ all_is_impossible,
430
+ )
431
+
432
+ return features, dataset
433
+ elif return_dataset == "tf":
434
+ if not is_tf_available():
435
+ raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
436
+
437
+ def gen():
438
+ for i, ex in enumerate(features):
439
+ if ex.token_type_ids is None:
440
+ yield (
441
+ {
442
+ "input_ids": ex.input_ids,
443
+ "attention_mask": ex.attention_mask,
444
+ "feature_index": i,
445
+ "qas_id": ex.qas_id,
446
+ },
447
+ {
448
+ "start_positions": ex.start_position,
449
+ "end_positions": ex.end_position,
450
+ "cls_index": ex.cls_index,
451
+ "p_mask": ex.p_mask,
452
+ "is_impossible": ex.is_impossible,
453
+ },
454
+ )
455
+ else:
456
+ yield (
457
+ {
458
+ "input_ids": ex.input_ids,
459
+ "attention_mask": ex.attention_mask,
460
+ "token_type_ids": ex.token_type_ids,
461
+ "feature_index": i,
462
+ "qas_id": ex.qas_id,
463
+ },
464
+ {
465
+ "start_positions": ex.start_position,
466
+ "end_positions": ex.end_position,
467
+ "cls_index": ex.cls_index,
468
+ "p_mask": ex.p_mask,
469
+ "is_impossible": ex.is_impossible,
470
+ },
471
+ )
472
+
473
+ # Why have we split the batch into a tuple? PyTorch just has a list of tensors.
474
+ if "token_type_ids" in tokenizer.model_input_names:
475
+ train_types = (
476
+ {
477
+ "input_ids": tf.int32,
478
+ "attention_mask": tf.int32,
479
+ "token_type_ids": tf.int32,
480
+ "feature_index": tf.int64,
481
+ "qas_id": tf.string,
482
+ },
483
+ {
484
+ "start_positions": tf.int64,
485
+ "end_positions": tf.int64,
486
+ "cls_index": tf.int64,
487
+ "p_mask": tf.int32,
488
+ "is_impossible": tf.int32,
489
+ },
490
+ )
491
+
492
+ train_shapes = (
493
+ {
494
+ "input_ids": tf.TensorShape([None]),
495
+ "attention_mask": tf.TensorShape([None]),
496
+ "token_type_ids": tf.TensorShape([None]),
497
+ "feature_index": tf.TensorShape([]),
498
+ "qas_id": tf.TensorShape([]),
499
+ },
500
+ {
501
+ "start_positions": tf.TensorShape([]),
502
+ "end_positions": tf.TensorShape([]),
503
+ "cls_index": tf.TensorShape([]),
504
+ "p_mask": tf.TensorShape([None]),
505
+ "is_impossible": tf.TensorShape([]),
506
+ },
507
+ )
508
+ else:
509
+ train_types = (
510
+ {"input_ids": tf.int32, "attention_mask": tf.int32, "feature_index": tf.int64, "qas_id": tf.string},
511
+ {
512
+ "start_positions": tf.int64,
513
+ "end_positions": tf.int64,
514
+ "cls_index": tf.int64,
515
+ "p_mask": tf.int32,
516
+ "is_impossible": tf.int32,
517
+ },
518
+ )
519
+
520
+ train_shapes = (
521
+ {
522
+ "input_ids": tf.TensorShape([None]),
523
+ "attention_mask": tf.TensorShape([None]),
524
+ "feature_index": tf.TensorShape([]),
525
+ "qas_id": tf.TensorShape([]),
526
+ },
527
+ {
528
+ "start_positions": tf.TensorShape([]),
529
+ "end_positions": tf.TensorShape([]),
530
+ "cls_index": tf.TensorShape([]),
531
+ "p_mask": tf.TensorShape([None]),
532
+ "is_impossible": tf.TensorShape([]),
533
+ },
534
+ )
535
+
536
+ return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
537
+ else:
538
+ return features
539
+
540
+
541
+ class SquadProcessor(DataProcessor):
542
+ """
543
+ Processor for the SQuAD data set. Overridden by SquadV1Processor and SquadV2Processor, used for version 1.1 and
544
+ version 2.0 of SQuAD, respectively.
545
+ """
546
+
547
+ train_file = None
548
+ dev_file = None
549
+
550
+ def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
551
+ if not evaluate:
552
+ answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
553
+ answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
554
+ answers = []
555
+ else:
556
+ answers = [
557
+ {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
558
+ for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
559
+ ]
560
+
561
+ answer = None
562
+ answer_start = None
563
+
564
+ return SquadExample(
565
+ qas_id=tensor_dict["id"].numpy().decode("utf-8"),
566
+ question_text=tensor_dict["question"].numpy().decode("utf-8"),
567
+ context_text=tensor_dict["context"].numpy().decode("utf-8"),
568
+ answer_text=answer,
569
+ start_position_character=answer_start,
570
+ title=tensor_dict["title"].numpy().decode("utf-8"),
571
+ answers=answers,
572
+ )
573
+
574
+ def get_examples_from_dataset(self, dataset, evaluate=False):
575
+ """
576
+ Creates a list of [`~data.processors.squad.SquadExample`] using a TFDS dataset.
577
+
578
+ Args:
579
+ dataset: The tfds dataset loaded from *tensorflow_datasets.load("squad")*
580
+ evaluate: Boolean specifying if in evaluation mode or in training mode
581
+
582
+ Returns:
583
+ List of SquadExample
584
+
585
+ Examples:
586
+
587
+ ```python
588
+ >>> import tensorflow_datasets as tfds
589
+
590
+ >>> dataset = tfds.load("squad")
591
+
592
+ >>> training_examples = get_examples_from_dataset(dataset, evaluate=False)
593
+ >>> evaluation_examples = get_examples_from_dataset(dataset, evaluate=True)
594
+ ```"""
595
+
596
+ if evaluate:
597
+ dataset = dataset["validation"]
598
+ else:
599
+ dataset = dataset["train"]
600
+
601
+ examples = []
602
+ for tensor_dict in tqdm(dataset):
603
+ examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))
604
+
605
+ return examples
606
+
607
+ def get_train_examples(self, data_dir, filename=None):
608
+ """
609
+ Returns the training examples from the data directory.
610
+
611
+ Args:
612
+ data_dir: Directory containing the data files used for training and evaluating.
613
+ filename: None by default, specify this if the training file has a different name than the original one
614
+ which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
615
+
616
+ """
617
+ if data_dir is None:
618
+ data_dir = ""
619
+
620
+ if self.train_file is None:
621
+ raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
622
+
623
+ with open(
624
+ os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
625
+ ) as reader:
626
+ input_data = json.load(reader)["data"]
627
+ return self._create_examples(input_data, "train")
628
+
629
+ def get_dev_examples(self, data_dir, filename=None):
630
+ """
631
+ Returns the evaluation example from the data directory.
632
+
633
+ Args:
634
+ data_dir: Directory containing the data files used for training and evaluating.
635
+ filename: None by default, specify this if the evaluation file has a different name than the original one
636
+ which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively.
637
+ """
638
+ if data_dir is None:
639
+ data_dir = ""
640
+
641
+ if self.dev_file is None:
642
+ raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
643
+
644
+ with open(
645
+ os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
646
+ ) as reader:
647
+ input_data = json.load(reader)["data"]
648
+ return self._create_examples(input_data, "dev")
649
+
650
+ def _create_examples(self, input_data, set_type):
651
+ is_training = set_type == "train"
652
+ examples = []
653
+ for entry in tqdm(input_data):
654
+ title = entry["title"]
655
+ for paragraph in entry["paragraphs"]:
656
+ context_text = paragraph["context"]
657
+ for qa in paragraph["qas"]:
658
+ qas_id = qa["id"]
659
+ question_text = qa["question"]
660
+ start_position_character = None
661
+ answer_text = None
662
+ answers = []
663
+
664
+ is_impossible = qa.get("is_impossible", False)
665
+ if not is_impossible:
666
+ if is_training:
667
+ answer = qa["answers"][0]
668
+ answer_text = answer["text"]
669
+ start_position_character = answer["answer_start"]
670
+ else:
671
+ answers = qa["answers"]
672
+
673
+ example = SquadExample(
674
+ qas_id=qas_id,
675
+ question_text=question_text,
676
+ context_text=context_text,
677
+ answer_text=answer_text,
678
+ start_position_character=start_position_character,
679
+ title=title,
680
+ is_impossible=is_impossible,
681
+ answers=answers,
682
+ )
683
+ examples.append(example)
684
+ return examples
685
+
686
+
687
+ class SquadV1Processor(SquadProcessor):
688
+ train_file = "train-v1.1.json"
689
+ dev_file = "dev-v1.1.json"
690
+
691
+
692
+ class SquadV2Processor(SquadProcessor):
693
+ train_file = "train-v2.0.json"
694
+ dev_file = "dev-v2.0.json"
695
+
696
+
697
+ class SquadExample:
698
+ """
699
+ A single training/test example for the Squad dataset, as loaded from disk.
700
+
701
+ Args:
702
+ qas_id: The example's unique identifier
703
+ question_text: The question string
704
+ context_text: The context string
705
+ answer_text: The answer string
706
+ start_position_character: The character position of the start of the answer
707
+ title: The title of the example
708
+ answers: None by default; this is used during evaluation. Holds the answers as well as their start positions.
709
+ is_impossible: False by default, set to True if the example has no possible answer.
710
+ """
711
+
712
+ def __init__(
713
+ self,
714
+ qas_id,
715
+ question_text,
716
+ context_text,
717
+ answer_text,
718
+ start_position_character,
719
+ title,
720
+ answers=[],
721
+ is_impossible=False,
722
+ ):
723
+ self.qas_id = qas_id
724
+ self.question_text = question_text
725
+ self.context_text = context_text
726
+ self.answer_text = answer_text
727
+ self.title = title
728
+ self.is_impossible = is_impossible
729
+ self.answers = answers
730
+
731
+ self.start_position, self.end_position = 0, 0
732
+
733
+ doc_tokens = []
734
+ char_to_word_offset = []
735
+ prev_is_whitespace = True
736
+
737
+ # Split on whitespace so that different tokens may be attributed to their original position.
738
+ for c in self.context_text:
739
+ if _is_whitespace(c):
740
+ prev_is_whitespace = True
741
+ else:
742
+ if prev_is_whitespace:
743
+ doc_tokens.append(c)
744
+ else:
745
+ doc_tokens[-1] += c
746
+ prev_is_whitespace = False
747
+ char_to_word_offset.append(len(doc_tokens) - 1)
748
+
749
+ self.doc_tokens = doc_tokens
750
+ self.char_to_word_offset = char_to_word_offset
751
+
752
+ # Start and end positions only have a value when an answer is provided (i.e. for training examples).
753
+ if start_position_character is not None and not is_impossible:
754
+ self.start_position = char_to_word_offset[start_position_character]
755
+ self.end_position = char_to_word_offset[
756
+ min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1)
757
+ ]
758
+
759
+
760
+ class SquadFeatures:
761
+ """
762
+ Single squad example features to be fed to a model. Those features are model-specific and can be crafted from
763
+ [`~data.processors.squad.SquadExample`] using the
764
+ [`~data.processors.squad.squad_convert_examples_to_features`] method.
765
+
766
+ Args:
767
+ input_ids: Indices of input sequence tokens in the vocabulary.
768
+ attention_mask: Mask to avoid performing attention on padding token indices.
769
+ token_type_ids: Segment token indices to indicate first and second portions of the inputs.
770
+ cls_index: the index of the CLS token.
771
+ p_mask: Mask identifying tokens that can be answers vs. tokens that cannot.
772
+ Mask with 1 for tokens that cannot be in the answer and 0 for tokens that can be in the answer.
773
+ example_index: the index of the example
774
+ unique_id: The unique Feature identifier
775
+ paragraph_len: The length of the context
776
+ token_is_max_context:
777
+ List of booleans identifying which tokens have their maximum context in this feature object. If a token
778
+ does not have its maximum context in this feature object, it means that another feature object has more
779
+ information related to that token and should be prioritized over this feature for that token.
780
+ tokens: list of tokens corresponding to the input ids
781
+ token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer.
782
+ start_position: start of the answer token index
783
+ end_position: end of the answer token index
784
+ encoding: optionally store the BatchEncoding with the fast-tokenizer alignment methods.
785
+ """
786
+
787
+ def __init__(
788
+ self,
789
+ input_ids,
790
+ attention_mask,
791
+ token_type_ids,
792
+ cls_index,
793
+ p_mask,
794
+ example_index,
795
+ unique_id,
796
+ paragraph_len,
797
+ token_is_max_context,
798
+ tokens,
799
+ token_to_orig_map,
800
+ start_position,
801
+ end_position,
802
+ is_impossible,
803
+ qas_id: str = None,
804
+ encoding: BatchEncoding = None,
805
+ ):
806
+ self.input_ids = input_ids
807
+ self.attention_mask = attention_mask
808
+ self.token_type_ids = token_type_ids
809
+ self.cls_index = cls_index
810
+ self.p_mask = p_mask
811
+
812
+ self.example_index = example_index
813
+ self.unique_id = unique_id
814
+ self.paragraph_len = paragraph_len
815
+ self.token_is_max_context = token_is_max_context
816
+ self.tokens = tokens
817
+ self.token_to_orig_map = token_to_orig_map
818
+
819
+ self.start_position = start_position
820
+ self.end_position = end_position
821
+ self.is_impossible = is_impossible
822
+ self.qas_id = qas_id
823
+
824
+ self.encoding = encoding
825
+
826
+
827
+ class SquadResult:
828
+ """
829
+ Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
830
+
831
+ Args:
832
+ unique_id: The unique identifier corresponding to that example.
833
+ start_logits: The logits corresponding to the start of the answer
834
+ end_logits: The logits corresponding to the end of the answer
835
+ """
836
+
837
+ def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
838
+ self.start_logits = start_logits
839
+ self.end_logits = end_logits
840
+ self.unique_id = unique_id
841
+
842
+ if start_top_index:
843
+ self.start_top_index = start_top_index
844
+ self.end_top_index = end_top_index
845
+ self.cls_logits = cls_logits
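
A minimal end-to-end usage sketch of the SQuAD processors above, assuming the standard `transformers` package layout, a slow BERT tokenizer, and a local copy of `dev-v2.0.json` under a hypothetical `./squad_data` directory (the path and the `[:100]` slice are illustrative only):

```python
from transformers import AutoTokenizer
from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)
processor = SquadV2Processor()
examples = processor.get_dev_examples("./squad_data")  # reads ./squad_data/dev-v2.0.json

# Convert a small slice of examples into padded feature windows plus a PyTorch dataset.
features, dataset = squad_convert_examples_to_features(
    examples=examples[:100],
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt",
    threads=1,
)
print(len(features), len(dataset))
```

With `return_dataset="pt"`, the call returns both the list of `SquadFeatures` and a `TensorDataset` that can be fed directly to a `DataLoader`.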
transformers_4_35_0/data/processors/utils.py ADDED
@@ -0,0 +1,349 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import csv
18
+ import dataclasses
19
+ import json
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Union
22
+
23
+ from ...utils import is_tf_available, is_torch_available, logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ @dataclass
30
+ class InputExample:
31
+ """
32
+ A single training/test example for simple sequence classification.
33
+
34
+ Args:
35
+ guid: Unique id for the example.
36
+ text_a: string. The untokenized text of the first sequence. For single
37
+ sequence tasks, only this sequence must be specified.
38
+ text_b: (Optional) string. The untokenized text of the second sequence.
39
+ Only must be specified for sequence pair tasks.
40
+ label: (Optional) string. The label of the example. This should be
41
+ specified for train and dev examples, but not for test examples.
42
+ """
43
+
44
+ guid: str
45
+ text_a: str
46
+ text_b: Optional[str] = None
47
+ label: Optional[str] = None
48
+
49
+ def to_json_string(self):
50
+ """Serializes this instance to a JSON string."""
51
+ return json.dumps(dataclasses.asdict(self), indent=2) + "\n"
52
+
53
+
54
+ @dataclass(frozen=True)
55
+ class InputFeatures:
56
+ """
57
+ A single set of features of data. Property names are the same names as the corresponding inputs to a model.
58
+
59
+ Args:
60
+ input_ids: Indices of input sequence tokens in the vocabulary.
61
+ attention_mask: Mask to avoid performing attention on padding token indices.
62
+ Mask values selected in `[0, 1]`: Usually `1` for tokens that are NOT MASKED, `0` for MASKED (padded)
63
+ tokens.
64
+ token_type_ids: (Optional) Segment token indices to indicate first and second
65
+ portions of the inputs. Only some models use them.
66
+ label: (Optional) Label corresponding to the input. Int for classification problems,
67
+ float for regression problems.
68
+ """
69
+
70
+ input_ids: List[int]
71
+ attention_mask: Optional[List[int]] = None
72
+ token_type_ids: Optional[List[int]] = None
73
+ label: Optional[Union[int, float]] = None
74
+
75
+ def to_json_string(self):
76
+ """Serializes this instance to a JSON string."""
77
+ return json.dumps(dataclasses.asdict(self)) + "\n"
78
+
79
+
80
+ class DataProcessor:
81
+ """Base class for data converters for sequence classification data sets."""
82
+
83
+ def get_example_from_tensor_dict(self, tensor_dict):
84
+ """
85
+ Gets an example from a dict with tensorflow tensors.
86
+
87
+ Args:
88
+ tensor_dict: Keys and values should match the corresponding Glue
89
+ tensorflow_dataset examples.
90
+ """
91
+ raise NotImplementedError()
92
+
93
+ def get_train_examples(self, data_dir):
94
+ """Gets a collection of [`InputExample`] for the train set."""
95
+ raise NotImplementedError()
96
+
97
+ def get_dev_examples(self, data_dir):
98
+ """Gets a collection of [`InputExample`] for the dev set."""
99
+ raise NotImplementedError()
100
+
101
+ def get_test_examples(self, data_dir):
102
+ """Gets a collection of [`InputExample`] for the test set."""
103
+ raise NotImplementedError()
104
+
105
+ def get_labels(self):
106
+ """Gets the list of labels for this data set."""
107
+ raise NotImplementedError()
108
+
109
+ def tfds_map(self, example):
110
+ """
111
+ Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. This method converts
112
+ examples to the correct format.
113
+ """
114
+ if len(self.get_labels()) > 1:
115
+ example.label = self.get_labels()[int(example.label)]
116
+ return example
117
+
118
+ @classmethod
119
+ def _read_tsv(cls, input_file, quotechar=None):
120
+ """Reads a tab separated value file."""
121
+ with open(input_file, "r", encoding="utf-8-sig") as f:
122
+ return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
123
+
124
+
125
+ class SingleSentenceClassificationProcessor(DataProcessor):
126
+ """Generic processor for a single sentence classification data set."""
127
+
128
+ def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
129
+ self.labels = [] if labels is None else labels
130
+ self.examples = [] if examples is None else examples
131
+ self.mode = mode
132
+ self.verbose = verbose
133
+
134
+ def __len__(self):
135
+ return len(self.examples)
136
+
137
+ def __getitem__(self, idx):
138
+ if isinstance(idx, slice):
139
+ return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
140
+ return self.examples[idx]
141
+
142
+ @classmethod
143
+ def create_from_csv(
144
+ cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
145
+ ):
146
+ processor = cls(**kwargs)
147
+ processor.add_examples_from_csv(
148
+ file_name,
149
+ split_name=split_name,
150
+ column_label=column_label,
151
+ column_text=column_text,
152
+ column_id=column_id,
153
+ skip_first_row=skip_first_row,
154
+ overwrite_labels=True,
155
+ overwrite_examples=True,
156
+ )
157
+ return processor
158
+
159
+ @classmethod
160
+ def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
161
+ processor = cls(**kwargs)
162
+ processor.add_examples(texts_or_text_and_labels, labels=labels)
163
+ return processor
164
+
165
+ def add_examples_from_csv(
166
+ self,
167
+ file_name,
168
+ split_name="",
169
+ column_label=0,
170
+ column_text=1,
171
+ column_id=None,
172
+ skip_first_row=False,
173
+ overwrite_labels=False,
174
+ overwrite_examples=False,
175
+ ):
176
+ lines = self._read_tsv(file_name)
177
+ if skip_first_row:
178
+ lines = lines[1:]
179
+ texts = []
180
+ labels = []
181
+ ids = []
182
+ for i, line in enumerate(lines):
183
+ texts.append(line[column_text])
184
+ labels.append(line[column_label])
185
+ if column_id is not None:
186
+ ids.append(line[column_id])
187
+ else:
188
+ guid = f"{split_name}-{i}" if split_name else str(i)
189
+ ids.append(guid)
190
+
191
+ return self.add_examples(
192
+ texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
193
+ )
194
+
195
+ def add_examples(
196
+ self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
197
+ ):
198
+ if labels is not None and len(texts_or_text_and_labels) != len(labels):
199
+ raise ValueError(
200
+ f"Text and labels have mismatched lengths {len(texts_or_text_and_labels)} and {len(labels)}"
201
+ )
202
+ if ids is not None and len(texts_or_text_and_labels) != len(ids):
203
+ raise ValueError(f"Text and ids have mismatched lengths {len(texts_or_text_and_labels)} and {len(ids)}")
204
+ if ids is None:
205
+ ids = [None] * len(texts_or_text_and_labels)
206
+ if labels is None:
207
+ labels = [None] * len(texts_or_text_and_labels)
208
+ examples = []
209
+ added_labels = set()
210
+ for text_or_text_and_label, label, guid in zip(texts_or_text_and_labels, labels, ids):
211
+ if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
212
+ text, label = text_or_text_and_label
213
+ else:
214
+ text = text_or_text_and_label
215
+ added_labels.add(label)
216
+ examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
217
+
218
+ # Update examples
219
+ if overwrite_examples:
220
+ self.examples = examples
221
+ else:
222
+ self.examples.extend(examples)
223
+
224
+ # Update labels
225
+ if overwrite_labels:
226
+ self.labels = list(added_labels)
227
+ else:
228
+ self.labels = list(set(self.labels).union(added_labels))
229
+
230
+ return self.examples
231
+
232
+ def get_features(
233
+ self,
234
+ tokenizer,
235
+ max_length=None,
236
+ pad_on_left=False,
237
+ pad_token=0,
238
+ mask_padding_with_zero=True,
239
+ return_tensors=None,
240
+ ):
241
+ """
242
+ Convert examples into a list of `InputFeatures`
243
+
244
+ Args:
245
+ tokenizer: Instance of a tokenizer that will tokenize the examples
246
+ max_length: Maximum example length
247
+ pad_on_left: If set to `True`, the examples will be padded on the left rather than on the right (default)
248
+ pad_token: Padding token
249
+ mask_padding_with_zero: If set to `True`, the attention mask will be filled by `1` for actual values
250
+ and by `0` for padded values. If set to `False`, inverts it (`1` for padded values, `0` for actual
251
+ values)
252
+
253
+ Returns:
254
+ If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the
255
+ task-specific features. If the input is a list of `InputExamples`, will return a list of task-specific
256
+ `InputFeatures` which can be fed to the model.
257
+
258
+ """
259
+ if max_length is None:
260
+ max_length = tokenizer.max_len
261
+
262
+ label_map = {label: i for i, label in enumerate(self.labels)}
263
+
264
+ all_input_ids = []
265
+ for ex_index, example in enumerate(self.examples):
266
+ if ex_index % 10000 == 0:
267
+ logger.info(f"Tokenizing example {ex_index}")
268
+
269
+ input_ids = tokenizer.encode(
270
+ example.text_a,
271
+ add_special_tokens=True,
272
+ max_length=min(max_length, tokenizer.max_len),
273
+ )
274
+ all_input_ids.append(input_ids)
275
+
276
+ batch_length = max(len(input_ids) for input_ids in all_input_ids)
277
+
278
+ features = []
279
+ for ex_index, (input_ids, example) in enumerate(zip(all_input_ids, self.examples)):
280
+ if ex_index % 10000 == 0:
281
+ logger.info(f"Writing example {ex_index}/{len(self.examples)}")
282
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
283
+ # tokens are attended to.
284
+ attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
285
+
286
+ # Zero-pad up to the sequence length.
287
+ padding_length = batch_length - len(input_ids)
288
+ if pad_on_left:
289
+ input_ids = ([pad_token] * padding_length) + input_ids
290
+ attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
291
+ else:
292
+ input_ids = input_ids + ([pad_token] * padding_length)
293
+ attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
294
+
295
+ if len(input_ids) != batch_length:
296
+ raise ValueError(f"Error with input length {len(input_ids)} vs {batch_length}")
297
+ if len(attention_mask) != batch_length:
298
+ raise ValueError(f"Error with input length {len(attention_mask)} vs {batch_length}")
299
+
300
+ if self.mode == "classification":
301
+ label = label_map[example.label]
302
+ elif self.mode == "regression":
303
+ label = float(example.label)
304
+ else:
305
+ raise ValueError(self.mode)
306
+
307
+ if ex_index < 5 and self.verbose:
308
+ logger.info("*** Example ***")
309
+ logger.info(f"guid: {example.guid}")
310
+ logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
311
+ logger.info(f"attention_mask: {' '.join([str(x) for x in attention_mask])}")
312
+ logger.info(f"label: {example.label} (id = {label})")
313
+
314
+ features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
315
+
316
+ if return_tensors is None:
317
+ return features
318
+ elif return_tensors == "tf":
319
+ if not is_tf_available():
320
+ raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
321
+ import tensorflow as tf
322
+
323
+ def gen():
324
+ for ex in features:
325
+ yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)
326
+
327
+ dataset = tf.data.Dataset.from_generator(
328
+ gen,
329
+ ({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
330
+ ({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
331
+ )
332
+ return dataset
333
+ elif return_tensors == "pt":
334
+ if not is_torch_available():
335
+ raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
336
+ import torch
337
+ from torch.utils.data import TensorDataset
338
+
339
+ all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
340
+ all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
341
+ if self.mode == "classification":
342
+ all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
343
+ elif self.mode == "regression":
344
+ all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
345
+
346
+ dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels)
347
+ return dataset
348
+ else:
349
+ raise ValueError("return_tensors should be one of 'tf' or 'pt'")
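
A minimal sketch of `SingleSentenceClassificationProcessor`, using made-up texts and labels; it only exercises the example bookkeeping defined above, not `get_features`:

```python
from transformers.data.processors.utils import SingleSentenceClassificationProcessor

# Build a processor from (text, label) pairs; labels are collected automatically.
processor = SingleSentenceClassificationProcessor.create_from_examples(
    [("great movie", "pos"), ("terrible plot", "neg")]
)
print(len(processor), sorted(processor.labels))  # 2 ['neg', 'pos']
print(processor[0].to_json_string())             # InputExample serialized as JSON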
transformers_4_35_0/data/processors/xnli.py ADDED
@@ -0,0 +1,97 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ XNLI utils (dataset loading and evaluation)"""
17
+
18
+
19
+ import os
20
+
21
+ from ...utils import logging
22
+ from .utils import DataProcessor, InputExample
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class XnliProcessor(DataProcessor):
29
+ """
30
+ Processor for the XNLI dataset. Adapted from
31
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207
32
+ """
33
+
34
+ def __init__(self, language, train_language=None):
35
+ self.language = language
36
+ self.train_language = train_language
37
+
38
+ def get_train_examples(self, data_dir):
39
+ """See base class."""
40
+ lg = self.language if self.train_language is None else self.train_language
41
+ lines = self._read_tsv(os.path.join(data_dir, f"XNLI-MT-1.0/multinli/multinli.train.{lg}.tsv"))
42
+ examples = []
43
+ for i, line in enumerate(lines):
44
+ if i == 0:
45
+ continue
46
+ guid = f"train-{i}"
47
+ text_a = line[0]
48
+ text_b = line[1]
49
+ label = "contradiction" if line[2] == "contradictory" else line[2]
50
+ if not isinstance(text_a, str):
51
+ raise ValueError(f"Training input {text_a} is not a string")
52
+ if not isinstance(text_b, str):
53
+ raise ValueError(f"Training input {text_b} is not a string")
54
+ if not isinstance(label, str):
55
+ raise ValueError(f"Training label {label} is not a string")
56
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
57
+ return examples
58
+
59
+ def get_test_examples(self, data_dir):
60
+ """See base class."""
61
+ lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv"))
62
+ examples = []
63
+ for i, line in enumerate(lines):
64
+ if i == 0:
65
+ continue
66
+ language = line[0]
67
+ if language != self.language:
68
+ continue
69
+ guid = f"test-{i}"
70
+ text_a = line[6]
71
+ text_b = line[7]
72
+ label = line[1]
73
+ if not isinstance(text_a, str):
74
+ raise ValueError(f"Training input {text_a} is not a string")
75
+ if not isinstance(text_b, str):
76
+ raise ValueError(f"Training input {text_b} is not a string")
77
+ if not isinstance(label, str):
78
+ raise ValueError(f"Training label {label} is not a string")
79
+ examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
80
+ return examples
81
+
82
+ def get_labels(self):
83
+ """See base class."""
84
+ return ["contradiction", "entailment", "neutral"]
85
+
86
+
87
+ xnli_processors = {
88
+ "xnli": XnliProcessor,
89
+ }
90
+
91
+ xnli_output_modes = {
92
+ "xnli": "classification",
93
+ }
94
+
95
+ xnli_tasks_num_labels = {
96
+ "xnli": 3,
97
+ }
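
A minimal sketch of `XnliProcessor`, assuming the XNLI-1.0 archive has been extracted under a hypothetical `./xnli_data` directory so that `./xnli_data/XNLI-1.0/xnli.test.tsv` exists:

```python
from transformers.data.processors.xnli import XnliProcessor

processor = XnliProcessor(language="de")              # keep only the German test rows
examples = processor.get_test_examples("./xnli_data")
print(len(examples), processor.get_labels())          # contradiction / entailment / neutral
print(examples[0].text_a, "->", examples[0].label)
```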
transformers_4_35_0/debug_utils.py ADDED
@@ -0,0 +1,346 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import collections
16
+
17
+ from .utils import ExplicitEnum, is_torch_available, logging
18
+
19
+
20
+ if is_torch_available():
21
+ import torch
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class DebugUnderflowOverflow:
28
+ """
29
+ This debug class helps detect and understand where the model starts getting very large or very small, and more
30
+ importantly `nan` or `inf` weight and activation elements.
31
+
32
+ There are 2 working modes:
33
+
34
+ 1. Underflow/overflow detection (default)
35
+ 2. Specific batch absolute min/max tracing without detection
36
+
37
+ Mode 1: Underflow/overflow detection
38
+
39
+ To activate the underflow/overflow detection, initialize the object with the model :
40
+
41
+ ```python
42
+ debug_overflow = DebugUnderflowOverflow(model)
43
+ ```
44
+
45
+ then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or output
46
+ elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this event,
47
+ each frame reporting
48
+
49
+ 1. the fully qualified module name plus the class name whose `forward` was run
50
+ 2. the absolute min and max value of all elements for each module weights, and the inputs and output
51
+
52
+ For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16
53
+ mixed precision :
54
+
55
+ ```
56
+ Detected inf/nan during batch_number=0
57
+ Last 21 forward frames:
58
+ abs min abs max metadata
59
+ [...]
60
+ encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
61
+ 2.17e-07 4.50e+00 weight
62
+ 1.79e-06 4.65e+00 input[0]
63
+ 2.68e-06 3.70e+01 output
64
+ encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
65
+ 8.08e-07 2.66e+01 weight
66
+ 1.79e-06 4.65e+00 input[0]
67
+ 1.27e-04 2.37e+02 output
68
+ encoder.block.2.layer.1.DenseReluDense.wo Linear
69
+ 1.01e-06 6.44e+00 weight
70
+ 0.00e+00 9.74e+03 input[0]
71
+ 3.18e-04 6.27e+04 output
72
+ encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
73
+ 1.79e-06 4.65e+00 input[0]
74
+ 3.18e-04 6.27e+04 output
75
+ encoder.block.2.layer.1.dropout Dropout
76
+ 3.18e-04 6.27e+04 input[0]
77
+ 0.00e+00 inf output
78
+ ```
79
+
80
+ You can see here that `T5DenseGatedGeluDense.forward` resulted in output activations whose absolute max value was
81
+ around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout` which
82
+ renormalizes the output after it has zeroed some of the elements, which pushes the absolute max value to more than
83
+ 64K, and we get an overflow.
84
+
85
+ As you can see, it's the previous frames that we need to look into when the numbers start getting very large for
86
+ fp16 numbers.
87
+
88
+ The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed.
89
+
90
+ By default the last 21 frames are printed. You can change the default to adjust for your needs. For example :
91
+
92
+ ```python
93
+ debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
94
+ ```
95
+
96
+ To validate that you have set up this debugging feature correctly, and you intend to use it in a training that
97
+ may take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in
98
+ the next section.
99
+
100
+
101
+ Mode 2. Specific batch absolute min/max tracing without detection
102
+
103
+ The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.
104
+
105
+ Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a
106
+ given batch, and only do that for batches 1 and 3. Then you instantiate this class as :
107
+
108
+ ```python
109
+ debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])
110
+ ```
111
+
112
+ And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.
113
+
114
+ This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
115
+ fast-forward right to that area.
116
+
117
+
118
+ Early stopping:
119
+
120
+ You can also specify the batch number after which to stop the training, with :
121
+
122
+ ```python
123
+ debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
124
+ ```
125
+
126
+ This feature is mainly useful in the tracing mode, but you can use it for any mode.
127
+
128
+
129
+ **Performance**:
130
+
131
+ As this module measures absolute `min`/``max` of each weight of the model on every forward it'll slow the training
132
+ down. Therefore remember to turn it off once the debugging needs have been met.
133
+
134
+ Args:
135
+ model (`nn.Module`):
136
+ The model to debug.
137
+ max_frames_to_save (`int`, *optional*, defaults to 21):
138
+ How many frames back to record
139
+ trace_batch_nums(`List[int]`, *optional*, defaults to `[]`):
140
+ Which batch numbers to trace (turns detection off)
141
+ abort_after_batch_num (`int``, *optional*):
142
+ Whether to abort after a certain batch number has finished
143
+ """
144
+
145
+ def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None):
146
+ self.model = model
147
+ self.trace_batch_nums = trace_batch_nums
148
+ self.abort_after_batch_num = abort_after_batch_num
149
+
150
+ # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
151
+ self.frames = collections.deque([], max_frames_to_save)
152
+ self.frame = []
153
+ self.batch_number = 0
154
+ self.total_calls = 0
155
+ self.detected_overflow = False
156
+ self.prefix = " "
157
+
158
+ self.analyse_model()
159
+
160
+ self.register_forward_hook()
161
+
162
+ def save_frame(self, frame=None):
163
+ if frame is not None:
164
+ self.expand_frame(frame)
165
+ self.frames.append("\n".join(self.frame))
166
+ self.frame = [] # start a new frame
167
+
168
+ def expand_frame(self, line):
169
+ self.frame.append(line)
170
+
171
+ def trace_frames(self):
172
+ print("\n".join(self.frames))
173
+ self.frames = []
174
+
175
+ def reset_saved_frames(self):
176
+ self.frames = []
177
+
178
+ def dump_saved_frames(self):
179
+ print(f"\nDetected inf/nan during batch_number={self.batch_number}")
180
+ print(f"Last {len(self.frames)} forward frames:")
181
+ print(f"{'abs min':8} {'abs max':8} metadata")
182
+ print("\n".join(self.frames))
183
+ print("\n\n")
184
+ self.frames = []
185
+
186
+ def analyse_model(self):
187
+ # extract the fully qualified module names, to be able to report at run time. e.g.:
188
+ # encoder.block.2.layer.0.SelfAttention.o
189
+ #
190
+ # for shared weights only the first shared module name will be registered
191
+ self.module_names = {m: name for name, m in self.model.named_modules()}
192
+ # self.longest_module_name = max(len(v) for v in self.module_names.values())
193
+
194
+ def analyse_variable(self, var, ctx):
195
+ if torch.is_tensor(var):
196
+ self.expand_frame(get_abs_min_max(var, ctx))
197
+ if detect_overflow(var, ctx):
198
+ self.detected_overflow = True
199
+ elif var is None:
200
+ self.expand_frame(f"{'None':>17} {ctx}")
201
+ else:
202
+ self.expand_frame(f"{'not a tensor':>17} {ctx}")
203
+
204
+ def batch_start_frame(self):
205
+ self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
206
+ self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")
207
+
208
+ def batch_end_frame(self):
209
+ self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n")
210
+
211
+ def create_frame(self, module, input, output):
212
+ self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")
213
+
214
+ # params
215
+ for name, p in module.named_parameters(recurse=False):
216
+ self.analyse_variable(p, name)
217
+
218
+ # inputs
219
+ if isinstance(input, tuple):
220
+ for i, x in enumerate(input):
221
+ self.analyse_variable(x, f"input[{i}]")
222
+ else:
223
+ self.analyse_variable(input, "input")
224
+
225
+ # outputs
226
+ if isinstance(output, tuple):
227
+ for i, x in enumerate(output):
228
+ # possibly a tuple of tuples
229
+ if isinstance(x, tuple):
230
+ for j, y in enumerate(x):
231
+ self.analyse_variable(y, f"output[{i}][{j}]")
232
+ else:
233
+ self.analyse_variable(x, f"output[{i}]")
234
+ else:
235
+ self.analyse_variable(output, "output")
236
+
237
+ self.save_frame()
238
+
239
+ def register_forward_hook(self):
240
+ self.model.apply(self._register_forward_hook)
241
+
242
+ def _register_forward_hook(self, module):
243
+ module.register_forward_hook(self.forward_hook)
244
+
245
+ def forward_hook(self, module, input, output):
246
+ # - input is a tuple of packed inputs (could be non-Tensors)
247
+ # - output could be a Tensor or a tuple of Tensors and non-Tensors
248
+
249
+ last_frame_of_batch = False
250
+
251
+ trace_mode = True if self.batch_number in self.trace_batch_nums else False
252
+ if trace_mode:
253
+ self.reset_saved_frames()
254
+
255
+ if self.total_calls == 0:
256
+ self.batch_start_frame()
257
+ self.total_calls += 1
258
+
259
+ # count batch numbers - the very first forward hook of the batch will be called when the
260
+ # batch completes - i.e. it gets called very last - we know this batch has finished
261
+ if module == self.model:
262
+ self.batch_number += 1
263
+ last_frame_of_batch = True
264
+
265
+ self.create_frame(module, input, output)
266
+
267
+ # if last_frame_of_batch:
268
+ # self.batch_end_frame()
269
+
270
+ if trace_mode:
271
+ self.trace_frames()
272
+
273
+ if last_frame_of_batch:
274
+ self.batch_start_frame()
275
+
276
+ if self.detected_overflow and not trace_mode:
277
+ self.dump_saved_frames()
278
+
279
+ # now we can abort, as it's pointless to continue running
280
+ raise ValueError(
281
+ "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
282
+ "Please scroll up above this traceback to see the activation values prior to this event."
283
+ )
284
+
285
+ # abort after certain batch if requested to do so
286
+ if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
287
+ raise ValueError(
288
+ f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
289
+ f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
290
+ )
291
+
292
+
293
+ def get_abs_min_max(var, ctx):
294
+ abs_var = var.abs()
295
+ return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}"
296
+
297
+
298
+ def detect_overflow(var, ctx):
299
+ """
300
+ Report whether the tensor contains any `nan` or `inf` entries.
301
+
302
+ This is useful for detecting overflows/underflows and best to call right after the function that did some math that
303
+ modified the tensor in question.
304
+
305
+ This function contains a few other helper features that you can enable and tweak directly if you want to track
306
+ various other things.
307
+
308
+ Args:
309
+ var: the tensor variable to check
310
+ ctx: the message to print as a context
311
+
312
+ Return:
313
+ `True` if `inf` or `nan` was detected, `False` otherwise
314
+ """
315
+ detected = False
316
+ if torch.isnan(var).any().item():
317
+ detected = True
318
+ print(f"{ctx} has nans")
319
+ if torch.isinf(var).any().item():
320
+ detected = True
321
+ print(f"{ctx} has infs")
322
+
323
+ # if needed to monitor large elements can enable the following
324
+ if 0: # and detected:
325
+ n100 = var[torch.ge(var.abs(), 100)]
326
+ if n100.numel() > 0:
327
+ print(f"{ctx}: n100={n100.numel()}")
328
+ n1000 = var[torch.ge(var.abs(), 1000)]
329
+ if n1000.numel() > 0:
330
+ print(f"{ctx}: n1000={n1000.numel()}")
331
+ n10000 = var[torch.ge(var.abs(), 10000)]
332
+ if n10000.numel() > 0:
333
+ print(f"{ctx}: n10000={n10000.numel()}")
334
+
335
+ if 0:
336
+ print(f"min={var.min():9.2e} max={var.max():9.2e}")
337
+
338
+ if 0:
339
+ print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")
340
+
341
+ return detected
342
+
343
+
344
+ class DebugOption(ExplicitEnum):
345
+ UNDERFLOW_OVERFLOW = "underflow_overflow"
346
+ TPU_METRICS_DEBUG = "tpu_metrics_debug"
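
The two helpers above are designed to be used together: `DebugUnderflowOverflow` hooks every submodule of a model and records min/max frames per forward call, while `detect_overflow` checks a single tensor on demand. A minimal usage sketch, assuming the vendored package is importable as `transformers_4_35_0` (upstream this lives in `transformers.debug_utils`):

    import torch
    from transformers_4_35_0.debug_utils import DebugUnderflowOverflow, detect_overflow

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    debug = DebugUnderflowOverflow(model)  # registers forward hooks on every submodule

    out = model(torch.randn(2, 4))         # frames are recorded; a ValueError aborts the run on inf/nan

    # Standalone check of a single tensor; prints a message and returns True only if inf/nan is found.
    detect_overflow(out, "sequential output")
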
transformers_4_35_0/deepspeed.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Integration with Deepspeed - kept for backward compatibility. If you plan to make any edit, make sure to modify the file
16
+ in `integrations/deepspeed` instead.
17
+
18
+ Check: https://github.com/huggingface/transformers/pull/25599
19
+ """
20
+ import warnings
21
+
22
+
23
+ warnings.warn(
24
+ "transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations",
25
+ FutureWarning,
26
+ )
27
+
28
+ # Backward compatibility imports, to make sure all those objects can be found in integrations/deepspeed
29
+ from .integrations.deepspeed import ( # noqa
30
+ HfDeepSpeedConfig,
31
+ HfTrainerDeepSpeedConfig,
32
+ deepspeed_config,
33
+ deepspeed_init,
34
+ deepspeed_load_checkpoint,
35
+ deepspeed_optim_sched,
36
+ is_deepspeed_available,
37
+ is_deepspeed_zero3_enabled,
38
+ set_hf_deepspeed_config,
39
+ unset_hf_deepspeed_config,
40
+ )
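
The shim above only re-exports names, so both import paths resolve to the same objects, with the old path emitting a `FutureWarning` on first import. A hedged sketch (module paths assume the vendored `transformers_4_35_0` package; substitute `transformers` for an upstream install):

    from transformers_4_35_0.deepspeed import is_deepspeed_available as via_shim     # warns once (FutureWarning)
    from transformers_4_35_0.integrations.deepspeed import is_deepspeed_available    # canonical location

    print(via_shim is is_deepspeed_available)  # True: the shim merely re-exports the same function
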
transformers_4_35_0/dependency_versions_check.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .dependency_versions_table import deps
16
+ from .utils.versions import require_version, require_version_core
17
+
18
+
19
+ # define which module versions we always want to check at run time
20
+ # (usually the ones defined in `install_requires` in setup.py)
21
+ #
22
+ # order specific notes:
23
+ # - tqdm must be checked before tokenizers
24
+
25
+ pkgs_to_check_at_runtime = [
26
+ "python",
27
+ "tqdm",
28
+ "regex",
29
+ "requests",
30
+ "packaging",
31
+ "filelock",
32
+ "numpy",
33
+ "tokenizers",
34
+ "huggingface-hub",
35
+ "safetensors",
36
+ "accelerate",
37
+ "pyyaml",
38
+ ]
39
+
40
+ for pkg in pkgs_to_check_at_runtime:
41
+ if pkg in deps:
42
+ if pkg == "tokenizers":
43
+ # must be loaded here, or else tqdm check may fail
44
+ from .utils import is_tokenizers_available
45
+
46
+ if not is_tokenizers_available():
47
+ continue # not required, check version only if installed
48
+ elif pkg == "accelerate":
49
+ # must be loaded here, or else tqdm check may fail
50
+ from .utils import is_accelerate_available
51
+
52
+ # Maybe switch to is_torch_available in the future here so that Accelerate is hard dep of
53
+ # Transformers with PyTorch
54
+ if not is_accelerate_available():
55
+ continue # not required, check version only if installed
56
+
57
+ require_version_core(deps[pkg])
58
+ else:
59
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
60
+
61
+
62
+ def dep_version_check(pkg, hint=None):
63
+ require_version(deps[pkg], hint)
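
The module-level loop above is what raises at import time if a core pin such as `tqdm>=4.27` is not satisfied, while optional packages like `tokenizers` and `accelerate` are only version-checked when installed. A small sketch of the underlying helper, assuming the vendored package layout:

    from transformers_4_35_0.dependency_versions_table import deps
    from transformers_4_35_0.utils.versions import require_version

    # The same call the loop above makes for each core package; raises if tqdm is
    # missing or does not satisfy the recorded pin.
    require_version(deps["tqdm"])

    # A hint string is appended to the error message when the check fails.
    require_version(deps["accelerate"], "Try: pip install -U accelerate")
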
transformers_4_35_0/dependency_versions_table.py ADDED
@@ -0,0 +1,90 @@
1
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
+ # 1. modify the `_deps` dict in setup.py
3
+ # 2. run `make deps_table_update`
4
+ deps = {
5
+ "Pillow": "Pillow<10.0.0",
6
+ "accelerate": "accelerate>=0.20.3",
7
+ "av": "av==9.2.0",
8
+ "beautifulsoup4": "beautifulsoup4",
9
+ "black": "black~=23.1",
10
+ "codecarbon": "codecarbon==1.2.0",
11
+ "cookiecutter": "cookiecutter==1.7.3",
12
+ "dataclasses": "dataclasses",
13
+ "datasets": "datasets!=2.5.0",
14
+ "decord": "decord==0.6.0",
15
+ "deepspeed": "deepspeed>=0.9.3",
16
+ "diffusers": "diffusers",
17
+ "dill": "dill<0.3.5",
18
+ "evaluate": "evaluate>=0.2.0",
19
+ "faiss-cpu": "faiss-cpu",
20
+ "fastapi": "fastapi",
21
+ "filelock": "filelock",
22
+ "flax": "flax>=0.4.1,<=0.7.0",
23
+ "ftfy": "ftfy",
24
+ "fugashi": "fugashi>=1.0",
25
+ "GitPython": "GitPython<3.1.19",
26
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
27
+ "huggingface-hub": "huggingface-hub>=0.16.4,<1.0",
28
+ "importlib_metadata": "importlib_metadata",
29
+ "ipadic": "ipadic>=1.0.0,<2.0",
30
+ "isort": "isort>=5.5.4",
31
+ "jax": "jax>=0.4.1,<=0.4.13",
32
+ "jaxlib": "jaxlib>=0.4.1,<=0.4.13",
33
+ "jieba": "jieba",
34
+ "kenlm": "kenlm",
35
+ "keras-nlp": "keras-nlp>=0.3.1",
36
+ "librosa": "librosa",
37
+ "nltk": "nltk",
38
+ "natten": "natten>=0.14.6",
39
+ "numpy": "numpy>=1.17",
40
+ "onnxconverter-common": "onnxconverter-common",
41
+ "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
42
+ "onnxruntime": "onnxruntime>=1.4.0",
43
+ "opencv-python": "opencv-python",
44
+ "optuna": "optuna",
45
+ "optax": "optax>=0.0.8,<=0.1.4",
46
+ "packaging": "packaging>=20.0",
47
+ "parameterized": "parameterized",
48
+ "phonemizer": "phonemizer",
49
+ "protobuf": "protobuf",
50
+ "psutil": "psutil",
51
+ "pyyaml": "pyyaml>=5.1",
52
+ "pydantic": "pydantic<2",
53
+ "pytest": "pytest>=7.2.0",
54
+ "pytest-timeout": "pytest-timeout",
55
+ "pytest-xdist": "pytest-xdist",
56
+ "python": "python>=3.8.0",
57
+ "ray[tune]": "ray[tune]",
58
+ "regex": "regex!=2019.12.17",
59
+ "requests": "requests",
60
+ "rhoknp": "rhoknp>=1.1.0,<1.3.1",
61
+ "rjieba": "rjieba",
62
+ "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
63
+ "ruff": "ruff>=0.0.241,<=0.0.259",
64
+ "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
65
+ "sacremoses": "sacremoses",
66
+ "safetensors": "safetensors>=0.3.1",
67
+ "sagemaker": "sagemaker>=2.31.0",
68
+ "scikit-learn": "scikit-learn",
69
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
70
+ "sigopt": "sigopt",
71
+ "starlette": "starlette",
72
+ "sudachipy": "sudachipy>=0.6.6",
73
+ "sudachidict_core": "sudachidict_core>=20220729",
74
+ "tensorflow-cpu": "tensorflow-cpu>=2.6,<2.15",
75
+ "tensorflow": "tensorflow>=2.6,<2.15",
76
+ "tensorflow-text": "tensorflow-text<2.15",
77
+ "tf2onnx": "tf2onnx",
78
+ "timeout-decorator": "timeout-decorator",
79
+ "timm": "timm",
80
+ "tokenizers": "tokenizers>=0.14,<0.15",
81
+ "torch": "torch>=1.10,!=1.12.0",
82
+ "torchaudio": "torchaudio",
83
+ "torchvision": "torchvision",
84
+ "pyctcdecode": "pyctcdecode>=0.4.0",
85
+ "tqdm": "tqdm>=4.27",
86
+ "unidic": "unidic>=1.0.2",
87
+ "unidic_lite": "unidic_lite>=1.0.7",
88
+ "urllib3": "urllib3<2.0.0",
89
+ "uvicorn": "uvicorn",
90
+ }
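
Because this table is autogenerated from `setup.py`, other modules look version pins up here rather than hard-coding them; `dep_version_check` from the previous file is the usual entry point. A hedged usage sketch:

    from transformers_4_35_0.dependency_versions_check import dep_version_check

    # Checks the installed torch against the pin recorded above ("torch>=1.10,!=1.12.0");
    # the optional hint is appended to the error if the requirement is not met.
    dep_version_check("torch", hint="See https://pytorch.org/ for installation instructions.")
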
transformers_4_35_0/dynamic_module_utils.py ADDED
@@ -0,0 +1,624 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Utilities to dynamically load objects from the Hub."""
16
+ import filecmp
17
+ import importlib
18
+ import os
19
+ import re
20
+ import shutil
21
+ import signal
22
+ import sys
23
+ import typing
24
+ import warnings
25
+ from pathlib import Path
26
+ from typing import Any, Dict, List, Optional, Union
27
+
28
+ from .utils import (
29
+ HF_MODULES_CACHE,
30
+ TRANSFORMERS_DYNAMIC_MODULE_NAME,
31
+ cached_file,
32
+ extract_commit_hash,
33
+ is_offline_mode,
34
+ logging,
35
+ try_to_load_from_cache,
36
+ )
37
+
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ def init_hf_modules():
43
+ """
44
+ Creates the cache directory for modules with an init, and adds it to the Python path.
45
+ """
46
+ # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
47
+ if HF_MODULES_CACHE in sys.path:
48
+ return
49
+
50
+ sys.path.append(HF_MODULES_CACHE)
51
+ os.makedirs(HF_MODULES_CACHE, exist_ok=True)
52
+ init_path = Path(HF_MODULES_CACHE) / "__init__.py"
53
+ if not init_path.exists():
54
+ init_path.touch()
55
+ importlib.invalidate_caches()
56
+
57
+
58
+ def create_dynamic_module(name: Union[str, os.PathLike]):
59
+ """
60
+ Creates a dynamic module in the cache directory for modules.
61
+
62
+ Args:
63
+ name (`str` or `os.PathLike`):
64
+ The name of the dynamic module to create.
65
+ """
66
+ init_hf_modules()
67
+ dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()
68
+ # If the parent module does not exist yet, recursively create it.
69
+ if not dynamic_module_path.parent.exists():
70
+ create_dynamic_module(dynamic_module_path.parent)
71
+ os.makedirs(dynamic_module_path, exist_ok=True)
72
+ init_path = dynamic_module_path / "__init__.py"
73
+ if not init_path.exists():
74
+ init_path.touch()
75
+ # It is extremely important to invalidate the cache when we change stuff in those modules, or users end up
76
+ # with errors about module that do not exist. Same for all other `invalidate_caches` in this file.
77
+ importlib.invalidate_caches()
78
+
79
+
80
+ def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
81
+ """
82
+ Get the list of modules that are relatively imported in a module file.
83
+
84
+ Args:
85
+ module_file (`str` or `os.PathLike`): The module file to inspect.
86
+
87
+ Returns:
88
+ `List[str]`: The list of relative imports in the module.
89
+ """
90
+ with open(module_file, "r", encoding="utf-8") as f:
91
+ content = f.read()
92
+
93
+ # Imports of the form `import .xxx`
94
+ relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
95
+ # Imports of the form `from .xxx import yyy`
96
+ relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
97
+ # Unique-ify
98
+ return list(set(relative_imports))
99
+
100
+
101
+ def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
102
+ """
103
+ Get the list of all files that are needed for a given module. Note that this function recurses through the relative
104
+ imports (if a imports b and b imports c, it will return module files for b and c).
105
+
106
+ Args:
107
+ module_file (`str` or `os.PathLike`): The module file to inspect.
108
+
109
+ Returns:
110
+ `List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
111
+ of module files a given module needs.
112
+ """
113
+ no_change = False
114
+ files_to_check = [module_file]
115
+ all_relative_imports = []
116
+
117
+ # Let's recurse through all relative imports
118
+ while not no_change:
119
+ new_imports = []
120
+ for f in files_to_check:
121
+ new_imports.extend(get_relative_imports(f))
122
+
123
+ module_path = Path(module_file).parent
124
+ new_import_files = [str(module_path / m) for m in new_imports]
125
+ new_import_files = [f for f in new_import_files if f not in all_relative_imports]
126
+ files_to_check = [f"{f}.py" for f in new_import_files]
127
+
128
+ no_change = len(new_import_files) == 0
129
+ all_relative_imports.extend(files_to_check)
130
+
131
+ return all_relative_imports
132
+
133
+
134
+ def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
135
+ """
136
+ Extracts all the libraries (not relative imports this time) that are imported in a file.
137
+
138
+ Args:
139
+ filename (`str` or `os.PathLike`): The module file to inspect.
140
+
141
+ Returns:
142
+ `List[str]`: The list of all packages required to use the input module.
143
+ """
144
+ with open(filename, "r", encoding="utf-8") as f:
145
+ content = f.read()
146
+
147
+ # strip try/except blocks so that imports guarded by try/except in custom code are not treated as hard requirements
148
+ content = re.sub(r"\s*try\s*:\s*.*?\s*except\s*.*?:", "", content, flags=re.MULTILINE | re.DOTALL)
149
+
150
+ # Imports of the form `import xxx`
151
+ imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
152
+ # Imports of the form `from xxx import yyy`
153
+ imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
154
+ # Only keep the top-level module
155
+ imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
156
+ return list(set(imports))
157
+
158
+
159
+ def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
160
+ """
161
+ Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
162
+ library is missing.
163
+
164
+ Args:
165
+ filename (`str` or `os.PathLike`): The module file to check.
166
+
167
+ Returns:
168
+ `List[str]`: The list of relative imports in the file.
169
+ """
170
+ imports = get_imports(filename)
171
+ missing_packages = []
172
+ for imp in imports:
173
+ try:
174
+ importlib.import_module(imp)
175
+ except ImportError:
176
+ missing_packages.append(imp)
177
+
178
+ if len(missing_packages) > 0:
179
+ raise ImportError(
180
+ "This modeling file requires the following packages that were not found in your environment: "
181
+ f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
182
+ )
183
+
184
+ return get_relative_imports(filename)
185
+
186
+
187
+ def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -> typing.Type:
188
+ """
189
+ Import a module on the cache directory for modules and extract a class from it.
190
+
191
+ Args:
192
+ class_name (`str`): The name of the class to import.
193
+ module_path (`str` or `os.PathLike`): The path to the module to import.
194
+
195
+ Returns:
196
+ `typing.Type`: The class looked for.
197
+ """
198
+ module_path = module_path.replace(os.path.sep, ".")
199
+ module = importlib.import_module(module_path)
200
+ return getattr(module, class_name)
201
+
202
+
203
+ def get_cached_module_file(
204
+ pretrained_model_name_or_path: Union[str, os.PathLike],
205
+ module_file: str,
206
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
207
+ force_download: bool = False,
208
+ resume_download: bool = False,
209
+ proxies: Optional[Dict[str, str]] = None,
210
+ token: Optional[Union[bool, str]] = None,
211
+ revision: Optional[str] = None,
212
+ local_files_only: bool = False,
213
+ repo_type: Optional[str] = None,
214
+ _commit_hash: Optional[str] = None,
215
+ **deprecated_kwargs,
216
+ ) -> str:
217
+ """
218
+ Downloads a module from a local folder or a distant repo and returns its path inside the cached
219
+ Transformers module.
220
+
221
+ Args:
222
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
223
+ This can be either:
224
+
225
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
226
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
227
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
228
+ - a path to a *directory* containing a configuration file saved using the
229
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
230
+
231
+ module_file (`str`):
232
+ The name of the module file containing the class to look for.
233
+ cache_dir (`str` or `os.PathLike`, *optional*):
234
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
235
+ cache should not be used.
236
+ force_download (`bool`, *optional*, defaults to `False`):
237
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
238
+ exist.
239
+ resume_download (`bool`, *optional*, defaults to `False`):
240
+ Whether or not to delete an incompletely received file. Attempts to resume the download if such a file exists.
241
+ proxies (`Dict[str, str]`, *optional*):
242
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
243
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
244
+ token (`str` or *bool*, *optional*):
245
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
246
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
247
+ revision (`str`, *optional*, defaults to `"main"`):
248
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
249
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
250
+ identifier allowed by git.
251
+ local_files_only (`bool`, *optional*, defaults to `False`):
252
+ If `True`, will only try to load the tokenizer configuration from local files.
253
+ repo_type (`str`, *optional*):
254
+ Specify the repo type (useful when downloading from a space for instance).
255
+
256
+ <Tip>
257
+
258
+ Passing `token=True` is required when you want to use a private model.
259
+
260
+ </Tip>
261
+
262
+ Returns:
263
+ `str`: The path to the module inside the cache.
264
+ """
265
+ use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
266
+ if use_auth_token is not None:
267
+ warnings.warn(
268
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
269
+ )
270
+ if token is not None:
271
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
272
+ token = use_auth_token
273
+
274
+ if is_offline_mode() and not local_files_only:
275
+ logger.info("Offline mode: forcing local_files_only=True")
276
+ local_files_only = True
277
+
278
+ # Download and cache module_file from the repo `pretrained_model_name_or_path`, or grab it if it's a local file.
279
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
280
+ is_local = os.path.isdir(pretrained_model_name_or_path)
281
+ if is_local:
282
+ submodule = os.path.basename(pretrained_model_name_or_path)
283
+ else:
284
+ submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
285
+ cached_module = try_to_load_from_cache(
286
+ pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
287
+ )
288
+
289
+ new_files = []
290
+ try:
291
+ # Load from URL or cache if already cached
292
+ resolved_module_file = cached_file(
293
+ pretrained_model_name_or_path,
294
+ module_file,
295
+ cache_dir=cache_dir,
296
+ force_download=force_download,
297
+ proxies=proxies,
298
+ resume_download=resume_download,
299
+ local_files_only=local_files_only,
300
+ token=token,
301
+ revision=revision,
302
+ repo_type=repo_type,
303
+ _commit_hash=_commit_hash,
304
+ )
305
+ if not is_local and cached_module != resolved_module_file:
306
+ new_files.append(module_file)
307
+
308
+ except EnvironmentError:
309
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
310
+ raise
311
+
312
+ # Check we have all the requirements in our environment
313
+ modules_needed = check_imports(resolved_module_file)
314
+
315
+ # Now we move the module inside our cached dynamic modules.
316
+ full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
317
+ create_dynamic_module(full_submodule)
318
+ submodule_path = Path(HF_MODULES_CACHE) / full_submodule
319
+ if submodule == os.path.basename(pretrained_model_name_or_path):
320
+ # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
321
+ # has changed since last copy.
322
+ if not (submodule_path / module_file).exists() or not filecmp.cmp(
323
+ resolved_module_file, str(submodule_path / module_file)
324
+ ):
325
+ shutil.copy(resolved_module_file, submodule_path / module_file)
326
+ importlib.invalidate_caches()
327
+ for module_needed in modules_needed:
328
+ module_needed = f"{module_needed}.py"
329
+ module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
330
+ if not (submodule_path / module_needed).exists() or not filecmp.cmp(
331
+ module_needed_file, str(submodule_path / module_needed)
332
+ ):
333
+ shutil.copy(module_needed_file, submodule_path / module_needed)
334
+ importlib.invalidate_caches()
335
+ else:
336
+ # Get the commit hash
337
+ commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)
338
+
339
+ # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
340
+ # benefit of versioning.
341
+ submodule_path = submodule_path / commit_hash
342
+ full_submodule = full_submodule + os.path.sep + commit_hash
343
+ create_dynamic_module(full_submodule)
344
+
345
+ if not (submodule_path / module_file).exists():
346
+ shutil.copy(resolved_module_file, submodule_path / module_file)
347
+ importlib.invalidate_caches()
348
+ # Make sure we also have every file that the module pulls in via relative imports
349
+ for module_needed in modules_needed:
350
+ if not (submodule_path / f"{module_needed}.py").exists():
351
+ get_cached_module_file(
352
+ pretrained_model_name_or_path,
353
+ f"{module_needed}.py",
354
+ cache_dir=cache_dir,
355
+ force_download=force_download,
356
+ resume_download=resume_download,
357
+ proxies=proxies,
358
+ token=token,
359
+ revision=revision,
360
+ local_files_only=local_files_only,
361
+ _commit_hash=commit_hash,
362
+ )
363
+ new_files.append(f"{module_needed}.py")
364
+
365
+ if len(new_files) > 0 and revision is None:
366
+ new_files = "\n".join([f"- {f}" for f in new_files])
367
+ repo_type_str = "" if repo_type is None else f"{repo_type}s/"
368
+ url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
369
+ logger.warning(
370
+ f"A new version of the following files was downloaded from {url}:\n{new_files}"
371
+ "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
372
+ "versions of the code file, you can pin a revision."
373
+ )
374
+
375
+ return os.path.join(full_submodule, module_file)
376
+
377
+
378
+ def get_class_from_dynamic_module(
379
+ class_reference: str,
380
+ pretrained_model_name_or_path: Union[str, os.PathLike],
381
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
382
+ force_download: bool = False,
383
+ resume_download: bool = False,
384
+ proxies: Optional[Dict[str, str]] = None,
385
+ token: Optional[Union[bool, str]] = None,
386
+ revision: Optional[str] = None,
387
+ local_files_only: bool = False,
388
+ repo_type: Optional[str] = None,
389
+ code_revision: Optional[str] = None,
390
+ **kwargs,
391
+ ) -> typing.Type:
392
+ """
393
+ Extracts a class from a module file, present in the local folder or repository of a model.
394
+
395
+ <Tip warning={true}>
396
+
397
+ Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
398
+ therefore only be called on trusted repos.
399
+
400
+ </Tip>
401
+
402
+ Args:
403
+ class_reference (`str`):
404
+ The full name of the class to load, including its module and optionally its repo.
405
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
406
+ This can be either:
407
+
408
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
409
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
410
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
411
+ - a path to a *directory* containing a configuration file saved using the
412
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
413
+
414
+ This is used when `class_reference` does not specify another repo.
415
+ module_file (`str`):
416
+ The name of the module file containing the class to look for.
417
+ class_name (`str`):
418
+ The name of the class to import in the module.
419
+ cache_dir (`str` or `os.PathLike`, *optional*):
420
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
421
+ cache should not be used.
422
+ force_download (`bool`, *optional*, defaults to `False`):
423
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
424
+ exist.
425
+ resume_download (`bool`, *optional*, defaults to `False`):
426
+ Whether or not to delete an incompletely received file. Attempts to resume the download if such a file exists.
427
+ proxies (`Dict[str, str]`, *optional*):
428
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
429
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
430
+ token (`str` or `bool`, *optional*):
431
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
432
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
433
+ revision (`str`, *optional*, defaults to `"main"`):
434
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
435
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
436
+ identifier allowed by git.
437
+ local_files_only (`bool`, *optional*, defaults to `False`):
438
+ If `True`, will only try to load the tokenizer configuration from local files.
439
+ repo_type (`str`, *optional*):
440
+ Specify the repo type (useful when downloading from a space for instance).
441
+ code_revision (`str`, *optional*, defaults to `"main"`):
442
+ The specific revision to use for the code on the Hub, if the code lives in a different repository than the
443
+ rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for
444
+ storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
445
+
446
+ <Tip>
447
+
448
+ Passing `token=True` is required when you want to use a private model.
449
+
450
+ </Tip>
451
+
452
+ Returns:
453
+ `typing.Type`: The class, dynamically imported from the module.
454
+
455
+ Examples:
456
+
457
+ ```python
458
+ # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
459
+ # module.
460
+ cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")
461
+
462
+ # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
463
+ # module.
464
+ cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
465
+ ```"""
466
+ use_auth_token = kwargs.pop("use_auth_token", None)
467
+ if use_auth_token is not None:
468
+ warnings.warn(
469
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
470
+ )
471
+ if token is not None:
472
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
473
+ token = use_auth_token
474
+
475
+ # Catch the name of the repo if it's specified in `class_reference`
476
+ if "--" in class_reference:
477
+ repo_id, class_reference = class_reference.split("--")
478
+ else:
479
+ repo_id = pretrained_model_name_or_path
480
+ module_file, class_name = class_reference.split(".")
481
+
482
+ if code_revision is None and pretrained_model_name_or_path == repo_id:
483
+ code_revision = revision
484
+ # And lastly we get the class inside our newly created module
485
+ final_module = get_cached_module_file(
486
+ repo_id,
487
+ module_file + ".py",
488
+ cache_dir=cache_dir,
489
+ force_download=force_download,
490
+ resume_download=resume_download,
491
+ proxies=proxies,
492
+ token=token,
493
+ revision=code_revision,
494
+ local_files_only=local_files_only,
495
+ repo_type=repo_type,
496
+ )
497
+ return get_class_in_module(class_name, final_module.replace(".py", ""))
498
+
499
+
500
+ def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
501
+ """
502
+ Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
503
+ adds the proper fields in a config.
504
+
505
+ Args:
506
+ obj (`Any`): The object for which to save the module files.
507
+ folder (`str` or `os.PathLike`): The folder where to save.
508
+ config (`PretrainedConfig` or dictionary, *optional*):
509
+ A config in which to register the auto_map corresponding to this custom object.
510
+
511
+ Returns:
512
+ `List[str]`: The list of files saved.
513
+ """
514
+ if obj.__module__ == "__main__":
515
+ logger.warning(
516
+ f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
517
+ "this code in a separate module so we can include it in the saved folder and make it easier to share via "
518
+ "the Hub."
519
+ )
520
+ return
521
+
522
+ def _set_auto_map_in_config(_config):
523
+ module_name = obj.__class__.__module__
524
+ last_module = module_name.split(".")[-1]
525
+ full_name = f"{last_module}.{obj.__class__.__name__}"
526
+ # Special handling for tokenizers
527
+ if "Tokenizer" in full_name:
528
+ slow_tokenizer_class = None
529
+ fast_tokenizer_class = None
530
+ if obj.__class__.__name__.endswith("Fast"):
531
+ # Fast tokenizer: we have the fast tokenizer class and we may have the slow one as an attribute.
532
+ fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
533
+ if getattr(obj, "slow_tokenizer_class", None) is not None:
534
+ slow_tokenizer = getattr(obj, "slow_tokenizer_class")
535
+ slow_tok_module_name = slow_tokenizer.__module__
536
+ last_slow_tok_module = slow_tok_module_name.split(".")[-1]
537
+ slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
538
+ else:
539
+ # Slow tokenizer: no way to have the fast class
540
+ slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
541
+
542
+ full_name = (slow_tokenizer_class, fast_tokenizer_class)
543
+
544
+ if isinstance(_config, dict):
545
+ auto_map = _config.get("auto_map", {})
546
+ auto_map[obj._auto_class] = full_name
547
+ _config["auto_map"] = auto_map
548
+ elif getattr(_config, "auto_map", None) is not None:
549
+ _config.auto_map[obj._auto_class] = full_name
550
+ else:
551
+ _config.auto_map = {obj._auto_class: full_name}
552
+
553
+ # Add object class to the config auto_map
554
+ if isinstance(config, (list, tuple)):
555
+ for cfg in config:
556
+ _set_auto_map_in_config(cfg)
557
+ elif config is not None:
558
+ _set_auto_map_in_config(config)
559
+
560
+ result = []
561
+ # Copy module file to the output folder.
562
+ object_file = sys.modules[obj.__module__].__file__
563
+ dest_file = Path(folder) / (Path(object_file).name)
564
+ shutil.copy(object_file, dest_file)
565
+ result.append(dest_file)
566
+
567
+ # Gather all relative imports recursively and make sure they are copied as well.
568
+ for needed_file in get_relative_import_files(object_file):
569
+ dest_file = Path(folder) / (Path(needed_file).name)
570
+ shutil.copy(needed_file, dest_file)
571
+ result.append(dest_file)
572
+
573
+ return result
574
+
575
+
576
+ def _raise_timeout_error(signum, frame):
577
+ raise ValueError(
578
+ "Loading this model requires you to execute custom code contained in the model repository on your local"
579
+ "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
580
+ )
581
+
582
+
583
+ TIME_OUT_REMOTE_CODE = 15
584
+
585
+
586
+ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
587
+ if trust_remote_code is None:
588
+ if has_local_code:
589
+ trust_remote_code = False
590
+ elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
591
+ try:
592
+ signal.signal(signal.SIGALRM, _raise_timeout_error)
593
+ signal.alarm(TIME_OUT_REMOTE_CODE)
594
+ while trust_remote_code is None:
595
+ answer = input(
596
+ f"The repository for {model_name} contains custom code which must be executed to correctly"
597
+ f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
598
+ f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
599
+ f"Do you wish to run the custom code? [y/N] "
600
+ )
601
+ if answer.lower() in ["yes", "y", "1"]:
602
+ trust_remote_code = True
603
+ elif answer.lower() in ["no", "n", "0", ""]:
604
+ trust_remote_code = False
605
+ signal.alarm(0)
606
+ except Exception:
607
+ # OS which does not support signal.SIGALRM
608
+ raise ValueError(
609
+ f"The repository for {model_name} contains custom code which must be executed to correctly"
610
+ f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
611
+ f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
612
+ )
613
+ elif has_remote_code:
614
+ # For the CI which puts the timeout at 0
615
+ _raise_timeout_error(None, None)
616
+
617
+ if has_remote_code and not has_local_code and not trust_remote_code:
618
+ raise ValueError(
619
+ f"Loading {model_name} requires you to execute the configuration file in that"
620
+ " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
621
+ " set the option `trust_remote_code=True` to remove this error."
622
+ )
623
+
624
+ return trust_remote_code
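
Taken together, the intended flow of this file is: resolve whether remote code may run, download and cache the module (plus its relative imports), then import the requested class from the cached copy. A condensed sketch of that flow; the repo id and class reference below are illustrative placeholders, not real endpoints:

    from transformers_4_35_0.dynamic_module_utils import (
        get_class_from_dynamic_module,
        resolve_trust_remote_code,
    )

    repo_id = "some-user/my-custom-model"  # hypothetical repo containing modeling_custom.py

    trust = resolve_trust_remote_code(
        trust_remote_code=True,            # passing True skips the interactive prompt defined above
        model_name=repo_id,
        has_local_code=False,
        has_remote_code=True,
    )

    if trust:
        # Downloads modeling_custom.py and its relative imports, caches them under
        # HF_MODULES_CACHE, and returns the class object ready to instantiate.
        cls = get_class_from_dynamic_module("modeling_custom.MyCustomModel", repo_id)
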