liuyizhang
commited on
Commit
•
1ce5e18
1
Parent(s):
77de6b0
add transformers_4_35_0
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +2 -0
- kosmos_utils.py +7 -2
- transformers_4_35_0/__init__.py +0 -0
- transformers_4_35_0/activations.py +251 -0
- transformers_4_35_0/activations_tf.py +134 -0
- transformers_4_35_0/audio_utils.py +721 -0
- transformers_4_35_0/benchmark/__init__.py +0 -0
- transformers_4_35_0/benchmark/benchmark.py +271 -0
- transformers_4_35_0/benchmark/benchmark_args.py +114 -0
- transformers_4_35_0/benchmark/benchmark_args_tf.py +136 -0
- transformers_4_35_0/benchmark/benchmark_args_utils.py +166 -0
- transformers_4_35_0/benchmark/benchmark_tf.py +303 -0
- transformers_4_35_0/benchmark/benchmark_utils.py +914 -0
- transformers_4_35_0/commands/__init__.py +27 -0
- transformers_4_35_0/commands/add_new_model.py +259 -0
- transformers_4_35_0/commands/add_new_model_like.py +1763 -0
- transformers_4_35_0/commands/convert.py +184 -0
- transformers_4_35_0/commands/download.py +56 -0
- transformers_4_35_0/commands/env.py +143 -0
- transformers_4_35_0/commands/lfs.py +226 -0
- transformers_4_35_0/commands/pt_to_tf.py +425 -0
- transformers_4_35_0/commands/run.py +110 -0
- transformers_4_35_0/commands/serving.py +228 -0
- transformers_4_35_0/commands/train.py +158 -0
- transformers_4_35_0/commands/transformers_cli.py +59 -0
- transformers_4_35_0/commands/user.py +197 -0
- transformers_4_35_0/configuration_utils.py +1075 -0
- transformers_4_35_0/convert_graph_to_onnx.py +569 -0
- transformers_4_35_0/convert_pytorch_checkpoint_to_tf2.py +492 -0
- transformers_4_35_0/convert_slow_tokenizer.py +1318 -0
- transformers_4_35_0/convert_slow_tokenizers_checkpoints_to_fast.py +126 -0
- transformers_4_35_0/convert_tf_hub_seq_to_seq_bert_to_pytorch.py +88 -0
- transformers_4_35_0/data/__init__.py +44 -0
- transformers_4_35_0/data/data_collator.py +1535 -0
- transformers_4_35_0/data/datasets/__init__.py +23 -0
- transformers_4_35_0/data/datasets/glue.py +161 -0
- transformers_4_35_0/data/datasets/language_modeling.py +530 -0
- transformers_4_35_0/data/datasets/squad.py +229 -0
- transformers_4_35_0/data/metrics/__init__.py +98 -0
- transformers_4_35_0/data/metrics/squad_metrics.py +780 -0
- transformers_4_35_0/data/processors/__init__.py +18 -0
- transformers_4_35_0/data/processors/glue.py +643 -0
- transformers_4_35_0/data/processors/squad.py +845 -0
- transformers_4_35_0/data/processors/utils.py +349 -0
- transformers_4_35_0/data/processors/xnli.py +97 -0
- transformers_4_35_0/debug_utils.py +346 -0
- transformers_4_35_0/deepspeed.py +40 -0
- transformers_4_35_0/dependency_versions_check.py +63 -0
- transformers_4_35_0/dependency_versions_table.py +90 -0
- transformers_4_35_0/dynamic_module_utils.py +624 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
*.pyc
|
kosmos_utils.py
CHANGED
@@ -1,11 +1,16 @@
|
|
1 |
import random
|
2 |
import numpy as np
|
3 |
-
import os
|
4 |
import requests
|
5 |
import torch
|
6 |
import torchvision.transforms as torchvision_T
|
7 |
from PIL import Image
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
9 |
import cv2
|
10 |
import ast
|
11 |
|
|
|
1 |
import random
|
2 |
import numpy as np
|
3 |
+
import os,sys
|
4 |
import requests
|
5 |
import torch
|
6 |
import torchvision.transforms as torchvision_T
|
7 |
from PIL import Image
|
8 |
+
|
9 |
+
# from transformers import AutoProcessor, AutoModelForVision2Seq
|
10 |
+
import subprocess, io, os, sys, time
|
11 |
+
sys.path.insert(0, './transformers_4_35_0')
|
12 |
+
from transformers_4_35_0 import AutoProcessor, AutoModelForVision2Seq
|
13 |
+
|
14 |
import cv2
|
15 |
import ast
|
16 |
|
transformers_4_35_0/__init__.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
transformers_4_35_0/activations.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
from collections import OrderedDict
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from packaging import version
|
20 |
+
from torch import Tensor, nn
|
21 |
+
|
22 |
+
from .utils import logging
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
class PytorchGELUTanh(nn.Module):
|
29 |
+
"""
|
30 |
+
A fast C implementation of the tanh approximation of the GeLU activation function. See
|
31 |
+
https://arxiv.org/abs/1606.08415.
|
32 |
+
|
33 |
+
This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
|
34 |
+
match due to rounding errors.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self):
|
38 |
+
super().__init__()
|
39 |
+
if version.parse(torch.__version__) < version.parse("1.12.0"):
|
40 |
+
raise ImportError(
|
41 |
+
f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
|
42 |
+
"PytorchGELUTanh. Please upgrade torch."
|
43 |
+
)
|
44 |
+
|
45 |
+
def forward(self, input: Tensor) -> Tensor:
|
46 |
+
return nn.functional.gelu(input, approximate="tanh")
|
47 |
+
|
48 |
+
|
49 |
+
class NewGELUActivation(nn.Module):
|
50 |
+
"""
|
51 |
+
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
52 |
+
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
53 |
+
"""
|
54 |
+
|
55 |
+
def forward(self, input: Tensor) -> Tensor:
|
56 |
+
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
|
57 |
+
|
58 |
+
|
59 |
+
class GELUActivation(nn.Module):
|
60 |
+
"""
|
61 |
+
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
|
62 |
+
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
|
63 |
+
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
|
64 |
+
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, use_gelu_python: bool = False):
|
68 |
+
super().__init__()
|
69 |
+
if use_gelu_python:
|
70 |
+
self.act = self._gelu_python
|
71 |
+
else:
|
72 |
+
self.act = nn.functional.gelu
|
73 |
+
|
74 |
+
def _gelu_python(self, input: Tensor) -> Tensor:
|
75 |
+
return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
|
76 |
+
|
77 |
+
def forward(self, input: Tensor) -> Tensor:
|
78 |
+
return self.act(input)
|
79 |
+
|
80 |
+
|
81 |
+
class FastGELUActivation(nn.Module):
|
82 |
+
"""
|
83 |
+
Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
|
84 |
+
"""
|
85 |
+
|
86 |
+
def forward(self, input: Tensor) -> Tensor:
|
87 |
+
return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
|
88 |
+
|
89 |
+
|
90 |
+
class QuickGELUActivation(nn.Module):
|
91 |
+
"""
|
92 |
+
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
|
93 |
+
"""
|
94 |
+
|
95 |
+
def forward(self, input: Tensor) -> Tensor:
|
96 |
+
return input * torch.sigmoid(1.702 * input)
|
97 |
+
|
98 |
+
|
99 |
+
class ClippedGELUActivation(nn.Module):
|
100 |
+
"""
|
101 |
+
Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
|
102 |
+
it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
|
103 |
+
https://arxiv.org/abs/2004.09602.
|
104 |
+
|
105 |
+
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
|
106 |
+
initially created.
|
107 |
+
|
108 |
+
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
|
109 |
+
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
|
110 |
+
"""
|
111 |
+
|
112 |
+
def __init__(self, min: float, max: float):
|
113 |
+
if min > max:
|
114 |
+
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
|
115 |
+
|
116 |
+
super().__init__()
|
117 |
+
self.min = min
|
118 |
+
self.max = max
|
119 |
+
|
120 |
+
def forward(self, x: Tensor) -> Tensor:
|
121 |
+
return torch.clip(gelu(x), self.min, self.max)
|
122 |
+
|
123 |
+
|
124 |
+
class AccurateGELUActivation(nn.Module):
|
125 |
+
"""
|
126 |
+
Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
|
127 |
+
https://github.com/hendrycks/GELUs
|
128 |
+
|
129 |
+
Implemented along with MEGA (Moving Average Equipped Gated Attention)
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self):
|
133 |
+
super().__init__()
|
134 |
+
self.precomputed_constant = math.sqrt(2 / math.pi)
|
135 |
+
|
136 |
+
def forward(self, input: Tensor) -> Tensor:
|
137 |
+
return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
|
138 |
+
|
139 |
+
|
140 |
+
class SiLUActivation(nn.Module):
|
141 |
+
"""
|
142 |
+
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
|
143 |
+
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
|
144 |
+
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
|
145 |
+
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
|
146 |
+
later.
|
147 |
+
"""
|
148 |
+
|
149 |
+
def forward(self, input: Tensor) -> Tensor:
|
150 |
+
return nn.functional.silu(input)
|
151 |
+
|
152 |
+
|
153 |
+
class MishActivation(nn.Module):
|
154 |
+
"""
|
155 |
+
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
|
156 |
+
visit the official repository for the paper: https://github.com/digantamisra98/Mish
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self):
|
160 |
+
super().__init__()
|
161 |
+
if version.parse(torch.__version__) < version.parse("1.9.0"):
|
162 |
+
self.act = self._mish_python
|
163 |
+
else:
|
164 |
+
self.act = nn.functional.mish
|
165 |
+
|
166 |
+
def _mish_python(self, input: Tensor) -> Tensor:
|
167 |
+
return input * torch.tanh(nn.functional.softplus(input))
|
168 |
+
|
169 |
+
def forward(self, input: Tensor) -> Tensor:
|
170 |
+
return self.act(input)
|
171 |
+
|
172 |
+
|
173 |
+
class LinearActivation(nn.Module):
|
174 |
+
"""
|
175 |
+
Applies the linear activation function, i.e. forwarding input directly to output.
|
176 |
+
"""
|
177 |
+
|
178 |
+
def forward(self, input: Tensor) -> Tensor:
|
179 |
+
return input
|
180 |
+
|
181 |
+
|
182 |
+
class LaplaceActivation(nn.Module):
|
183 |
+
"""
|
184 |
+
Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
|
185 |
+
https://arxiv.org/abs/2209.10655
|
186 |
+
|
187 |
+
Inspired by squared relu, but with bounded range and gradient for better stability
|
188 |
+
"""
|
189 |
+
|
190 |
+
def forward(self, input, mu=0.707107, sigma=0.282095):
|
191 |
+
input = (input - mu).div(sigma * math.sqrt(2.0))
|
192 |
+
return 0.5 * (1.0 + torch.erf(input))
|
193 |
+
|
194 |
+
|
195 |
+
class ReLUSquaredActivation(nn.Module):
|
196 |
+
"""
|
197 |
+
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
|
198 |
+
"""
|
199 |
+
|
200 |
+
def forward(self, input):
|
201 |
+
relu_applied = nn.functional.relu(input)
|
202 |
+
squared = torch.square(relu_applied)
|
203 |
+
return squared
|
204 |
+
|
205 |
+
|
206 |
+
class ClassInstantier(OrderedDict):
|
207 |
+
def __getitem__(self, key):
|
208 |
+
content = super().__getitem__(key)
|
209 |
+
cls, kwargs = content if isinstance(content, tuple) else (content, {})
|
210 |
+
return cls(**kwargs)
|
211 |
+
|
212 |
+
|
213 |
+
ACT2CLS = {
|
214 |
+
"gelu": GELUActivation,
|
215 |
+
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
|
216 |
+
"gelu_fast": FastGELUActivation,
|
217 |
+
"gelu_new": NewGELUActivation,
|
218 |
+
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
|
219 |
+
"gelu_pytorch_tanh": PytorchGELUTanh,
|
220 |
+
"gelu_accurate": AccurateGELUActivation,
|
221 |
+
"laplace": LaplaceActivation,
|
222 |
+
"linear": LinearActivation,
|
223 |
+
"mish": MishActivation,
|
224 |
+
"quick_gelu": QuickGELUActivation,
|
225 |
+
"relu": nn.ReLU,
|
226 |
+
"relu2": ReLUSquaredActivation,
|
227 |
+
"relu6": nn.ReLU6,
|
228 |
+
"sigmoid": nn.Sigmoid,
|
229 |
+
"silu": SiLUActivation,
|
230 |
+
"swish": SiLUActivation,
|
231 |
+
"tanh": nn.Tanh,
|
232 |
+
}
|
233 |
+
ACT2FN = ClassInstantier(ACT2CLS)
|
234 |
+
|
235 |
+
|
236 |
+
def get_activation(activation_string):
|
237 |
+
if activation_string in ACT2FN:
|
238 |
+
return ACT2FN[activation_string]
|
239 |
+
else:
|
240 |
+
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
|
241 |
+
|
242 |
+
|
243 |
+
# For backwards compatibility with: from activations import gelu_python
|
244 |
+
gelu_python = get_activation("gelu_python")
|
245 |
+
gelu_new = get_activation("gelu_new")
|
246 |
+
gelu = get_activation("gelu")
|
247 |
+
gelu_fast = get_activation("gelu_fast")
|
248 |
+
quick_gelu = get_activation("quick_gelu")
|
249 |
+
silu = get_activation("silu")
|
250 |
+
mish = get_activation("mish")
|
251 |
+
linear_act = get_activation("linear")
|
transformers_4_35_0/activations_tf.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
|
17 |
+
import tensorflow as tf
|
18 |
+
from packaging import version
|
19 |
+
|
20 |
+
|
21 |
+
def _gelu(x):
|
22 |
+
"""
|
23 |
+
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
|
24 |
+
initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
25 |
+
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
|
26 |
+
https://arxiv.org/abs/1606.08415
|
27 |
+
"""
|
28 |
+
x = tf.convert_to_tensor(x)
|
29 |
+
cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
|
30 |
+
|
31 |
+
return x * cdf
|
32 |
+
|
33 |
+
|
34 |
+
def _gelu_new(x):
|
35 |
+
"""
|
36 |
+
Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.0841
|
37 |
+
|
38 |
+
Args:
|
39 |
+
x: float Tensor to perform activation
|
40 |
+
|
41 |
+
Returns:
|
42 |
+
`x` with the GELU activation applied.
|
43 |
+
"""
|
44 |
+
x = tf.convert_to_tensor(x)
|
45 |
+
pi = tf.cast(math.pi, x.dtype)
|
46 |
+
coeff = tf.cast(0.044715, x.dtype)
|
47 |
+
cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
|
48 |
+
|
49 |
+
return x * cdf
|
50 |
+
|
51 |
+
|
52 |
+
def mish(x):
|
53 |
+
x = tf.convert_to_tensor(x)
|
54 |
+
|
55 |
+
return x * tf.tanh(tf.math.softplus(x))
|
56 |
+
|
57 |
+
|
58 |
+
def gelu_fast(x):
|
59 |
+
x = tf.convert_to_tensor(x)
|
60 |
+
coeff1 = tf.cast(0.044715, x.dtype)
|
61 |
+
coeff2 = tf.cast(0.7978845608, x.dtype)
|
62 |
+
|
63 |
+
return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))
|
64 |
+
|
65 |
+
|
66 |
+
def quick_gelu(x):
|
67 |
+
x = tf.convert_to_tensor(x)
|
68 |
+
coeff = tf.cast(1.702, x.dtype)
|
69 |
+
return x * tf.math.sigmoid(coeff * x)
|
70 |
+
|
71 |
+
|
72 |
+
def gelu_10(x):
|
73 |
+
"""
|
74 |
+
Clip the range of possible GeLU outputs between [-10, 10]. This is especially useful for quantization purpose, as
|
75 |
+
it allows mapping 2 negatives values in the GeLU spectrum. For more information on this trick, please refer to
|
76 |
+
https://arxiv.org/abs/2004.09602
|
77 |
+
|
78 |
+
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
|
79 |
+
initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
80 |
+
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
|
81 |
+
https://arxiv.org/abs/1606.08415 :param x: :return:
|
82 |
+
"""
|
83 |
+
return tf.clip_by_value(_gelu(x), -10, 10)
|
84 |
+
|
85 |
+
|
86 |
+
def glu(x, axis=-1):
|
87 |
+
"""
|
88 |
+
Gated Linear Unit. Implementation as defined in the original paper (see https://arxiv.org/abs/1612.08083), where
|
89 |
+
the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B).
|
90 |
+
|
91 |
+
Args:
|
92 |
+
`x`: float Tensor to perform activation
|
93 |
+
`axis`: dimension across which `x` be split in half
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
`x` with the GLU activation applied (with its size halved across the dimension `axis`).
|
97 |
+
"""
|
98 |
+
a, b = tf.split(x, 2, axis=axis)
|
99 |
+
return a * tf.math.sigmoid(b)
|
100 |
+
|
101 |
+
|
102 |
+
if version.parse(tf.version.VERSION) >= version.parse("2.4"):
|
103 |
+
|
104 |
+
def approximate_gelu_wrap(x):
|
105 |
+
return tf.keras.activations.gelu(x, approximate=True)
|
106 |
+
|
107 |
+
gelu = tf.keras.activations.gelu
|
108 |
+
gelu_new = approximate_gelu_wrap
|
109 |
+
else:
|
110 |
+
gelu = _gelu
|
111 |
+
gelu_new = _gelu_new
|
112 |
+
|
113 |
+
|
114 |
+
ACT2FN = {
|
115 |
+
"gelu": gelu,
|
116 |
+
"gelu_10": gelu_10,
|
117 |
+
"gelu_fast": gelu_fast,
|
118 |
+
"gelu_new": gelu_new,
|
119 |
+
"glu": glu,
|
120 |
+
"mish": mish,
|
121 |
+
"quick_gelu": quick_gelu,
|
122 |
+
"relu": tf.keras.activations.relu,
|
123 |
+
"sigmoid": tf.keras.activations.sigmoid,
|
124 |
+
"silu": tf.keras.activations.swish,
|
125 |
+
"swish": tf.keras.activations.swish,
|
126 |
+
"tanh": tf.keras.activations.tanh,
|
127 |
+
}
|
128 |
+
|
129 |
+
|
130 |
+
def get_tf_activation(activation_string):
|
131 |
+
if activation_string in ACT2FN:
|
132 |
+
return ACT2FN[activation_string]
|
133 |
+
else:
|
134 |
+
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
|
transformers_4_35_0/audio_utils.py
ADDED
@@ -0,0 +1,721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
|
17 |
+
and remove unnecessary dependencies.
|
18 |
+
"""
|
19 |
+
import warnings
|
20 |
+
from typing import Optional, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
|
25 |
+
def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
|
26 |
+
"""
|
27 |
+
Convert frequency from hertz to mels.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
freq (`float` or `np.ndarray`):
|
31 |
+
The frequency, or multiple frequencies, in hertz (Hz).
|
32 |
+
mel_scale (`str`, *optional*, defaults to `"htk"`):
|
33 |
+
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
`float` or `np.ndarray`: The frequencies on the mel scale.
|
37 |
+
"""
|
38 |
+
|
39 |
+
if mel_scale not in ["slaney", "htk", "kaldi"]:
|
40 |
+
raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
|
41 |
+
|
42 |
+
if mel_scale == "htk":
|
43 |
+
return 2595.0 * np.log10(1.0 + (freq / 700.0))
|
44 |
+
elif mel_scale == "kaldi":
|
45 |
+
return 1127.0 * np.log(1.0 + (freq / 700.0))
|
46 |
+
|
47 |
+
min_log_hertz = 1000.0
|
48 |
+
min_log_mel = 15.0
|
49 |
+
logstep = 27.0 / np.log(6.4)
|
50 |
+
mels = 3.0 * freq / 200.0
|
51 |
+
|
52 |
+
if isinstance(freq, np.ndarray):
|
53 |
+
log_region = freq >= min_log_hertz
|
54 |
+
mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
|
55 |
+
elif freq >= min_log_hertz:
|
56 |
+
mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
|
57 |
+
|
58 |
+
return mels
|
59 |
+
|
60 |
+
|
61 |
+
def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
|
62 |
+
"""
|
63 |
+
Convert frequency from mels to hertz.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
mels (`float` or `np.ndarray`):
|
67 |
+
The frequency, or multiple frequencies, in mels.
|
68 |
+
mel_scale (`str`, *optional*, `"htk"`):
|
69 |
+
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
`float` or `np.ndarray`: The frequencies in hertz.
|
73 |
+
"""
|
74 |
+
|
75 |
+
if mel_scale not in ["slaney", "htk", "kaldi"]:
|
76 |
+
raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
|
77 |
+
|
78 |
+
if mel_scale == "htk":
|
79 |
+
return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
|
80 |
+
elif mel_scale == "kaldi":
|
81 |
+
return 700.0 * (np.exp(mels / 1127.0) - 1.0)
|
82 |
+
|
83 |
+
min_log_hertz = 1000.0
|
84 |
+
min_log_mel = 15.0
|
85 |
+
logstep = np.log(6.4) / 27.0
|
86 |
+
freq = 200.0 * mels / 3.0
|
87 |
+
|
88 |
+
if isinstance(mels, np.ndarray):
|
89 |
+
log_region = mels >= min_log_mel
|
90 |
+
freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
|
91 |
+
elif mels >= min_log_mel:
|
92 |
+
freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
|
93 |
+
|
94 |
+
return freq
|
95 |
+
|
96 |
+
|
97 |
+
def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
|
98 |
+
"""
|
99 |
+
Creates a triangular filter bank.
|
100 |
+
|
101 |
+
Adapted from *torchaudio* and *librosa*.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
|
105 |
+
Discrete frequencies of the FFT bins in Hz.
|
106 |
+
filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
|
107 |
+
Center frequencies of the triangular filters to create, in Hz.
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
`np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
|
111 |
+
"""
|
112 |
+
filter_diff = np.diff(filter_freqs)
|
113 |
+
slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
|
114 |
+
down_slopes = -slopes[:, :-2] / filter_diff[:-1]
|
115 |
+
up_slopes = slopes[:, 2:] / filter_diff[1:]
|
116 |
+
return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
|
117 |
+
|
118 |
+
|
119 |
+
def mel_filter_bank(
|
120 |
+
num_frequency_bins: int,
|
121 |
+
num_mel_filters: int,
|
122 |
+
min_frequency: float,
|
123 |
+
max_frequency: float,
|
124 |
+
sampling_rate: int,
|
125 |
+
norm: Optional[str] = None,
|
126 |
+
mel_scale: str = "htk",
|
127 |
+
triangularize_in_mel_space: bool = False,
|
128 |
+
) -> np.ndarray:
|
129 |
+
"""
|
130 |
+
Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
|
131 |
+
various implementation exist, which differ in the number of filters, the shape of the filters, the way the filters
|
132 |
+
are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
|
133 |
+
features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.
|
134 |
+
|
135 |
+
Different banks of mel filters were introduced in the literature. The following variations are supported:
|
136 |
+
|
137 |
+
- MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech
|
138 |
+
bandwidth of `[0, 4600]` Hz.
|
139 |
+
- MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech
|
140 |
+
bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz.
|
141 |
+
- MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and
|
142 |
+
speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
|
143 |
+
- HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of
|
144 |
+
12.5 kHz and speech bandwidth of `[0, 6250]` Hz.
|
145 |
+
|
146 |
+
This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
|
147 |
+
`melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
num_frequency_bins (`int`):
|
151 |
+
Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
|
152 |
+
num_mel_filters (`int`):
|
153 |
+
Number of mel filters to generate.
|
154 |
+
min_frequency (`float`):
|
155 |
+
Lowest frequency of interest in Hz.
|
156 |
+
max_frequency (`float`):
|
157 |
+
Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
|
158 |
+
sampling_rate (`int`):
|
159 |
+
Sample rate of the audio waveform.
|
160 |
+
norm (`str`, *optional*):
|
161 |
+
If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
|
162 |
+
mel_scale (`str`, *optional*, defaults to `"htk"`):
|
163 |
+
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
|
164 |
+
triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
|
165 |
+
If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
|
166 |
+
should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
`np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
|
170 |
+
projection matrix to go from a spectrogram to a mel spectrogram.
|
171 |
+
"""
|
172 |
+
if norm is not None and norm != "slaney":
|
173 |
+
raise ValueError('norm must be one of None or "slaney"')
|
174 |
+
|
175 |
+
# center points of the triangular mel filters
|
176 |
+
mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
|
177 |
+
mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
|
178 |
+
mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
|
179 |
+
filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
|
180 |
+
|
181 |
+
if triangularize_in_mel_space:
|
182 |
+
# frequencies of FFT bins in Hz, but filters triangularized in mel space
|
183 |
+
fft_bin_width = sampling_rate / (num_frequency_bins * 2)
|
184 |
+
fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
|
185 |
+
filter_freqs = mel_freqs
|
186 |
+
else:
|
187 |
+
# frequencies of FFT bins in Hz
|
188 |
+
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
|
189 |
+
|
190 |
+
mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
|
191 |
+
|
192 |
+
if norm is not None and norm == "slaney":
|
193 |
+
# Slaney-style mel is scaled to be approx constant energy per channel
|
194 |
+
enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
|
195 |
+
mel_filters *= np.expand_dims(enorm, 0)
|
196 |
+
|
197 |
+
if (mel_filters.max(axis=0) == 0.0).any():
|
198 |
+
warnings.warn(
|
199 |
+
"At least one mel filter has all zero values. "
|
200 |
+
f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
|
201 |
+
f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
|
202 |
+
)
|
203 |
+
|
204 |
+
return mel_filters
|
205 |
+
|
206 |
+
|
207 |
+
def optimal_fft_length(window_length: int) -> int:
|
208 |
+
"""
|
209 |
+
Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
|
210 |
+
already a power of two, rounds it up to the next power or two.
|
211 |
+
|
212 |
+
The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
|
213 |
+
of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
|
214 |
+
is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
|
215 |
+
it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
|
216 |
+
"""
|
217 |
+
return 2 ** int(np.ceil(np.log2(window_length)))
|
218 |
+
|
219 |
+
|
220 |
+
def window_function(
|
221 |
+
window_length: int,
|
222 |
+
name: str = "hann",
|
223 |
+
periodic: bool = True,
|
224 |
+
frame_length: Optional[int] = None,
|
225 |
+
center: bool = True,
|
226 |
+
) -> np.ndarray:
|
227 |
+
"""
|
228 |
+
Returns an array containing the specified window. This window is intended to be used with `stft`.
|
229 |
+
|
230 |
+
The following window types are supported:
|
231 |
+
|
232 |
+
- `"boxcar"`: a rectangular window
|
233 |
+
- `"hamming"`: the Hamming window
|
234 |
+
- `"hann"`: the Hann window
|
235 |
+
- `"povey"`: the Povey window
|
236 |
+
|
237 |
+
Args:
|
238 |
+
window_length (`int`):
|
239 |
+
The length of the window in samples.
|
240 |
+
name (`str`, *optional*, defaults to `"hann"`):
|
241 |
+
The name of the window function.
|
242 |
+
periodic (`bool`, *optional*, defaults to `True`):
|
243 |
+
Whether the window is periodic or symmetric.
|
244 |
+
frame_length (`int`, *optional*):
|
245 |
+
The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
|
246 |
+
than the frame length, so that it will be zero-padded.
|
247 |
+
center (`bool`, *optional*, defaults to `True`):
|
248 |
+
Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
`np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
|
252 |
+
"""
|
253 |
+
length = window_length + 1 if periodic else window_length
|
254 |
+
|
255 |
+
if name == "boxcar":
|
256 |
+
window = np.ones(length)
|
257 |
+
elif name in ["hamming", "hamming_window"]:
|
258 |
+
window = np.hamming(length)
|
259 |
+
elif name in ["hann", "hann_window"]:
|
260 |
+
window = np.hanning(length)
|
261 |
+
elif name in ["povey"]:
|
262 |
+
window = np.power(np.hanning(length), 0.85)
|
263 |
+
else:
|
264 |
+
raise ValueError(f"Unknown window function '{name}'")
|
265 |
+
|
266 |
+
if periodic:
|
267 |
+
window = window[:-1]
|
268 |
+
|
269 |
+
if frame_length is None:
|
270 |
+
return window
|
271 |
+
|
272 |
+
if window_length > frame_length:
|
273 |
+
raise ValueError(
|
274 |
+
f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
|
275 |
+
)
|
276 |
+
|
277 |
+
padded_window = np.zeros(frame_length)
|
278 |
+
offset = (frame_length - window_length) // 2 if center else 0
|
279 |
+
padded_window[offset : offset + window_length] = window
|
280 |
+
return padded_window
|
281 |
+
|
282 |
+
|
283 |
+
# TODO This method does not support batching yet as we are mainly focused on inference.
|
284 |
+
def spectrogram(
|
285 |
+
waveform: np.ndarray,
|
286 |
+
window: np.ndarray,
|
287 |
+
frame_length: int,
|
288 |
+
hop_length: int,
|
289 |
+
fft_length: Optional[int] = None,
|
290 |
+
power: Optional[float] = 1.0,
|
291 |
+
center: bool = True,
|
292 |
+
pad_mode: str = "reflect",
|
293 |
+
onesided: bool = True,
|
294 |
+
preemphasis: Optional[float] = None,
|
295 |
+
mel_filters: Optional[np.ndarray] = None,
|
296 |
+
mel_floor: float = 1e-10,
|
297 |
+
log_mel: Optional[str] = None,
|
298 |
+
reference: float = 1.0,
|
299 |
+
min_value: float = 1e-10,
|
300 |
+
db_range: Optional[float] = None,
|
301 |
+
remove_dc_offset: Optional[bool] = None,
|
302 |
+
dtype: np.dtype = np.float32,
|
303 |
+
) -> np.ndarray:
|
304 |
+
"""
|
305 |
+
Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
|
306 |
+
|
307 |
+
This function can create the following kinds of spectrograms:
|
308 |
+
|
309 |
+
- amplitude spectrogram (`power = 1.0`)
|
310 |
+
- power spectrogram (`power = 2.0`)
|
311 |
+
- complex-valued spectrogram (`power = None`)
|
312 |
+
- log spectrogram (use `log_mel` argument)
|
313 |
+
- mel spectrogram (provide `mel_filters`)
|
314 |
+
- log-mel spectrogram (provide `mel_filters` and `log_mel`)
|
315 |
+
|
316 |
+
How this works:
|
317 |
+
|
318 |
+
1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
|
319 |
+
- hop_length` samples.
|
320 |
+
2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
|
321 |
+
3. The DFT is taken of each windowed frame.
|
322 |
+
4. The results are stacked into a spectrogram.
|
323 |
+
|
324 |
+
We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
|
325 |
+
|
326 |
+
- The analysis frame. This is the size of the time slices that the input waveform is split into.
|
327 |
+
- The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
|
328 |
+
- The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
|
329 |
+
|
330 |
+
In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
|
331 |
+
padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
|
332 |
+
typically the next power of two.
|
333 |
+
|
334 |
+
Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
|
335 |
+
`torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
|
336 |
+
can be constructed.
|
337 |
+
|
338 |
+
Args:
|
339 |
+
waveform (`np.ndarray` of shape `(length,)`):
|
340 |
+
The input waveform. This must be a single real-valued, mono waveform.
|
341 |
+
window (`np.ndarray` of shape `(frame_length,)`):
|
342 |
+
The windowing function to apply, including zero-padding if necessary. The actual window length may be
|
343 |
+
shorter than `frame_length`, but we're assuming the array has already been zero-padded.
|
344 |
+
frame_length (`int`):
|
345 |
+
The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also
|
346 |
+
allow smaller sizes.
|
347 |
+
hop_length (`int`):
|
348 |
+
The stride between successive analysis frames in samples.
|
349 |
+
fft_length (`int`, *optional*):
|
350 |
+
The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
|
351 |
+
For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
|
352 |
+
power (`float`, *optional*, defaults to 1.0):
|
353 |
+
If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
|
354 |
+
complex numbers.
|
355 |
+
center (`bool`, *optional*, defaults to `True`):
|
356 |
+
Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
|
357 |
+
`t` will start at time `t * hop_length`.
|
358 |
+
pad_mode (`str`, *optional*, defaults to `"reflect"`):
|
359 |
+
Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
|
360 |
+
(pad with edge values), `"reflect"` (pads with mirrored values).
|
361 |
+
onesided (`bool`, *optional*, defaults to `True`):
|
362 |
+
If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
|
363 |
+
frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
|
364 |
+
preemphasis (`float`, *optional*)
|
365 |
+
Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
|
366 |
+
mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
|
367 |
+
The mel filter bank. If supplied, applies a this filter bank to create a mel spectrogram.
|
368 |
+
mel_floor (`float`, *optional*, defaults to 1e-10):
|
369 |
+
Minimum value of mel frequency banks.
|
370 |
+
log_mel (`str`, *optional*):
|
371 |
+
How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
|
372 |
+
the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
|
373 |
+
used when `power` is not `None`.
|
374 |
+
reference (`float`, *optional*, defaults to 1.0):
|
375 |
+
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
|
376 |
+
the loudest part to 0 dB. Must be greater than zero.
|
377 |
+
min_value (`float`, *optional*, defaults to `1e-10`):
|
378 |
+
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
|
379 |
+
`log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
|
380 |
+
amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
|
381 |
+
db_range (`float`, *optional*):
|
382 |
+
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
383 |
+
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
384 |
+
remove_dc_offset (`bool`, *optional*):
|
385 |
+
Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
|
386 |
+
order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
|
387 |
+
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
|
388 |
+
Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
|
389 |
+
`np.complex64`.
|
390 |
+
|
391 |
+
Returns:
|
392 |
+
`nd.array` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
|
393 |
+
`(num_mel_filters, length)` for a mel spectrogram.
|
394 |
+
"""
|
395 |
+
window_length = len(window)
|
396 |
+
|
397 |
+
if fft_length is None:
|
398 |
+
fft_length = frame_length
|
399 |
+
|
400 |
+
if frame_length > fft_length:
|
401 |
+
raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
|
402 |
+
|
403 |
+
if window_length != frame_length:
|
404 |
+
raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
|
405 |
+
|
406 |
+
if hop_length <= 0:
|
407 |
+
raise ValueError("hop_length must be greater than zero")
|
408 |
+
|
409 |
+
if waveform.ndim != 1:
|
410 |
+
raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
|
411 |
+
|
412 |
+
if np.iscomplexobj(waveform):
|
413 |
+
raise ValueError("Complex-valued input waveforms are not currently supported")
|
414 |
+
|
415 |
+
# center pad the waveform
|
416 |
+
if center:
|
417 |
+
padding = [(int(frame_length // 2), int(frame_length // 2))]
|
418 |
+
waveform = np.pad(waveform, padding, mode=pad_mode)
|
419 |
+
|
420 |
+
# promote to float64, since np.fft uses float64 internally
|
421 |
+
waveform = waveform.astype(np.float64)
|
422 |
+
window = window.astype(np.float64)
|
423 |
+
|
424 |
+
# split waveform into frames of frame_length size
|
425 |
+
num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))
|
426 |
+
|
427 |
+
num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
|
428 |
+
spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)
|
429 |
+
|
430 |
+
# rfft is faster than fft
|
431 |
+
fft_func = np.fft.rfft if onesided else np.fft.fft
|
432 |
+
buffer = np.zeros(fft_length)
|
433 |
+
|
434 |
+
timestep = 0
|
435 |
+
for frame_idx in range(num_frames):
|
436 |
+
buffer[:frame_length] = waveform[timestep : timestep + frame_length]
|
437 |
+
|
438 |
+
if remove_dc_offset:
|
439 |
+
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
|
440 |
+
|
441 |
+
if preemphasis is not None:
|
442 |
+
buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
|
443 |
+
buffer[0] *= 1 - preemphasis
|
444 |
+
|
445 |
+
buffer[:frame_length] *= window
|
446 |
+
|
447 |
+
spectrogram[frame_idx] = fft_func(buffer)
|
448 |
+
timestep += hop_length
|
449 |
+
|
450 |
+
# note: ** is much faster than np.power
|
451 |
+
if power is not None:
|
452 |
+
spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
|
453 |
+
|
454 |
+
spectrogram = spectrogram.T
|
455 |
+
|
456 |
+
if mel_filters is not None:
|
457 |
+
spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
|
458 |
+
|
459 |
+
if power is not None and log_mel is not None:
|
460 |
+
if log_mel == "log":
|
461 |
+
spectrogram = np.log(spectrogram)
|
462 |
+
elif log_mel == "log10":
|
463 |
+
spectrogram = np.log10(spectrogram)
|
464 |
+
elif log_mel == "dB":
|
465 |
+
if power == 1.0:
|
466 |
+
spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
|
467 |
+
elif power == 2.0:
|
468 |
+
spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
|
469 |
+
else:
|
470 |
+
raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
|
471 |
+
else:
|
472 |
+
raise ValueError(f"Unknown log_mel option: {log_mel}")
|
473 |
+
|
474 |
+
spectrogram = np.asarray(spectrogram, dtype)
|
475 |
+
|
476 |
+
return spectrogram
|
477 |
+
|
478 |
+
|
479 |
+
def power_to_db(
|
480 |
+
spectrogram: np.ndarray,
|
481 |
+
reference: float = 1.0,
|
482 |
+
min_value: float = 1e-10,
|
483 |
+
db_range: Optional[float] = None,
|
484 |
+
) -> np.ndarray:
|
485 |
+
"""
|
486 |
+
Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic
|
487 |
+
logarithm properties for numerical stability.
|
488 |
+
|
489 |
+
The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
|
490 |
+
linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
|
491 |
+
This means that large variations in energy may not sound all that different if the sound is loud to begin with.
|
492 |
+
This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
|
493 |
+
|
494 |
+
Based on the implementation of `librosa.power_to_db`.
|
495 |
+
|
496 |
+
Args:
|
497 |
+
spectrogram (`np.ndarray`):
|
498 |
+
The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
|
499 |
+
reference (`float`, *optional*, defaults to 1.0):
|
500 |
+
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
|
501 |
+
the loudest part to 0 dB. Must be greater than zero.
|
502 |
+
min_value (`float`, *optional*, defaults to `1e-10`):
|
503 |
+
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
|
504 |
+
`log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
|
505 |
+
db_range (`float`, *optional*):
|
506 |
+
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
507 |
+
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
508 |
+
|
509 |
+
Returns:
|
510 |
+
`np.ndarray`: the spectrogram in decibels
|
511 |
+
"""
|
512 |
+
if reference <= 0.0:
|
513 |
+
raise ValueError("reference must be greater than zero")
|
514 |
+
if min_value <= 0.0:
|
515 |
+
raise ValueError("min_value must be greater than zero")
|
516 |
+
|
517 |
+
reference = max(min_value, reference)
|
518 |
+
|
519 |
+
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
|
520 |
+
spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
|
521 |
+
|
522 |
+
if db_range is not None:
|
523 |
+
if db_range <= 0.0:
|
524 |
+
raise ValueError("db_range must be greater than zero")
|
525 |
+
spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
|
526 |
+
|
527 |
+
return spectrogram
|
528 |
+
|
529 |
+
|
530 |
+
def amplitude_to_db(
|
531 |
+
spectrogram: np.ndarray,
|
532 |
+
reference: float = 1.0,
|
533 |
+
min_value: float = 1e-5,
|
534 |
+
db_range: Optional[float] = None,
|
535 |
+
) -> np.ndarray:
|
536 |
+
"""
|
537 |
+
Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using
|
538 |
+
basic logarithm properties for numerical stability.
|
539 |
+
|
540 |
+
The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
|
541 |
+
linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
|
542 |
+
This means that large variations in energy may not sound all that different if the sound is loud to begin with.
|
543 |
+
This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
|
544 |
+
|
545 |
+
Args:
|
546 |
+
spectrogram (`np.ndarray`):
|
547 |
+
The input amplitude (mel) spectrogram.
|
548 |
+
reference (`float`, *optional*, defaults to 1.0):
|
549 |
+
Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
|
550 |
+
the loudest part to 0 dB. Must be greater than zero.
|
551 |
+
min_value (`float`, *optional*, defaults to `1e-5`):
|
552 |
+
The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
|
553 |
+
`log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
|
554 |
+
db_range (`float`, *optional*):
|
555 |
+
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
|
556 |
+
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
|
557 |
+
|
558 |
+
Returns:
|
559 |
+
`np.ndarray`: the spectrogram in decibels
|
560 |
+
"""
|
561 |
+
if reference <= 0.0:
|
562 |
+
raise ValueError("reference must be greater than zero")
|
563 |
+
if min_value <= 0.0:
|
564 |
+
raise ValueError("min_value must be greater than zero")
|
565 |
+
|
566 |
+
reference = max(min_value, reference)
|
567 |
+
|
568 |
+
spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
|
569 |
+
spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
|
570 |
+
|
571 |
+
if db_range is not None:
|
572 |
+
if db_range <= 0.0:
|
573 |
+
raise ValueError("db_range must be greater than zero")
|
574 |
+
spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
|
575 |
+
|
576 |
+
return spectrogram
|
577 |
+
|
578 |
+
|
579 |
+
### deprecated functions below this line ###
|
580 |
+
|
581 |
+
|
582 |
+
def get_mel_filter_banks(
|
583 |
+
nb_frequency_bins: int,
|
584 |
+
nb_mel_filters: int,
|
585 |
+
frequency_min: float,
|
586 |
+
frequency_max: float,
|
587 |
+
sample_rate: int,
|
588 |
+
norm: Optional[str] = None,
|
589 |
+
mel_scale: str = "htk",
|
590 |
+
) -> np.array:
|
591 |
+
warnings.warn(
|
592 |
+
"The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers",
|
593 |
+
FutureWarning,
|
594 |
+
)
|
595 |
+
return mel_filter_bank(
|
596 |
+
num_frequency_bins=nb_frequency_bins,
|
597 |
+
num_mel_filters=nb_mel_filters,
|
598 |
+
min_frequency=frequency_min,
|
599 |
+
max_frequency=frequency_max,
|
600 |
+
sampling_rate=sample_rate,
|
601 |
+
norm=norm,
|
602 |
+
mel_scale=mel_scale,
|
603 |
+
)
|
604 |
+
|
605 |
+
|
606 |
+
def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True):
|
607 |
+
"""
|
608 |
+
In order to compute the short time fourier transform, the waveform needs to be split in overlapping windowed
|
609 |
+
segments called `frames`.
|
610 |
+
|
611 |
+
The window length (window_length) defines how much of the signal is contained in each frame, while the hop length
|
612 |
+
defines the step between the beginning of each new frame.
|
613 |
+
|
614 |
+
|
615 |
+
Args:
|
616 |
+
waveform (`np.array` of shape `(sample_length,)`):
|
617 |
+
The raw waveform which will be split into smaller chunks.
|
618 |
+
hop_length (`int`, *optional*, defaults to 160):
|
619 |
+
Step between each window of the waveform.
|
620 |
+
fft_window_size (`int`, *optional*, defaults to 400):
|
621 |
+
Defines the size of the window.
|
622 |
+
center (`bool`, defaults to `True`):
|
623 |
+
Whether or not to center each frame around the middle of the frame. Centering is done by reflecting the
|
624 |
+
waveform on the left and on the right.
|
625 |
+
|
626 |
+
Return:
|
627 |
+
framed_waveform (`np.array` of shape `(waveform.shape // hop_length , fft_window_size)`):
|
628 |
+
The framed waveforms that can be fed to `np.fft`.
|
629 |
+
"""
|
630 |
+
warnings.warn(
|
631 |
+
"The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers",
|
632 |
+
FutureWarning,
|
633 |
+
)
|
634 |
+
frames = []
|
635 |
+
for i in range(0, waveform.shape[0] + 1, hop_length):
|
636 |
+
if center:
|
637 |
+
half_window = (fft_window_size - 1) // 2 + 1
|
638 |
+
start = i - half_window if i > half_window else 0
|
639 |
+
end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
|
640 |
+
frame = waveform[start:end]
|
641 |
+
if start == 0:
|
642 |
+
padd_width = (-i + half_window, 0)
|
643 |
+
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
|
644 |
+
|
645 |
+
elif end == waveform.shape[0]:
|
646 |
+
padd_width = (0, (i - waveform.shape[0] + half_window))
|
647 |
+
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
|
648 |
+
|
649 |
+
else:
|
650 |
+
frame = waveform[i : i + fft_window_size]
|
651 |
+
frame_width = frame.shape[0]
|
652 |
+
if frame_width < waveform.shape[0]:
|
653 |
+
frame = np.lib.pad(
|
654 |
+
frame, pad_width=(0, fft_window_size - frame_width), mode="constant", constant_values=0
|
655 |
+
)
|
656 |
+
frames.append(frame)
|
657 |
+
|
658 |
+
frames = np.stack(frames, 0)
|
659 |
+
return frames
|
660 |
+
|
661 |
+
|
662 |
+
def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None):
|
663 |
+
"""
|
664 |
+
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results
|
665 |
+
as `torch.stft`.
|
666 |
+
|
667 |
+
Args:
|
668 |
+
frames (`np.array` of dimension `(num_frames, fft_window_size)`):
|
669 |
+
A framed audio signal obtained using `audio_utils.fram_wav`.
|
670 |
+
windowing_function (`np.array` of dimension `(nb_frequency_bins, nb_mel_filters)`:
|
671 |
+
A array reprensenting the function that will be used to reduces the amplitude of the discontinuities at the
|
672 |
+
boundaries of each frame when computing the STFT. Each frame will be multiplied by the windowing_function.
|
673 |
+
For more information on the discontinuities, called *Spectral leakage*, refer to [this
|
674 |
+
tutorial]https://download.ni.com/evaluation/pxi/Understanding%20FFTs%20and%20Windowing.pdf
|
675 |
+
fft_window_size (`int`, *optional*):
|
676 |
+
Size of the window om which the Fourier transform is applied. This controls the frequency resolution of the
|
677 |
+
spectrogram. 400 means that the fourrier transform is computed on windows of 400 samples. The number of
|
678 |
+
frequency bins (`nb_frequency_bins`) used to divide the window into equal strips is equal to
|
679 |
+
`(1+fft_window_size)//2`. An increase of the fft_window_size slows the calculus time proportionnally.
|
680 |
+
|
681 |
+
Example:
|
682 |
+
|
683 |
+
```python
|
684 |
+
>>> from transformers.audio_utils import stft, fram_wave
|
685 |
+
>>> import numpy as np
|
686 |
+
|
687 |
+
>>> audio = np.random.rand(50)
|
688 |
+
>>> fft_window_size = 10
|
689 |
+
>>> hop_length = 2
|
690 |
+
>>> framed_audio = fram_wave(audio, hop_length, fft_window_size)
|
691 |
+
>>> spectrogram = stft(framed_audio, np.hanning(fft_window_size + 1))
|
692 |
+
```
|
693 |
+
|
694 |
+
Returns:
|
695 |
+
spectrogram (`np.ndarray`):
|
696 |
+
A spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using the STFT algorithm
|
697 |
+
"""
|
698 |
+
warnings.warn(
|
699 |
+
"The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers",
|
700 |
+
FutureWarning,
|
701 |
+
)
|
702 |
+
frame_size = frames.shape[1]
|
703 |
+
|
704 |
+
if fft_window_size is None:
|
705 |
+
fft_window_size = frame_size
|
706 |
+
|
707 |
+
if fft_window_size < frame_size:
|
708 |
+
raise ValueError("FFT size must greater or equal the frame size")
|
709 |
+
# number of FFT bins to store
|
710 |
+
nb_frequency_bins = (fft_window_size >> 1) + 1
|
711 |
+
|
712 |
+
spectrogram = np.empty((len(frames), nb_frequency_bins), dtype=np.complex64)
|
713 |
+
fft_signal = np.zeros(fft_window_size)
|
714 |
+
|
715 |
+
for f, frame in enumerate(frames):
|
716 |
+
if windowing_function is not None:
|
717 |
+
np.multiply(frame, windowing_function, out=fft_signal[:frame_size])
|
718 |
+
else:
|
719 |
+
fft_signal[:frame_size] = frame
|
720 |
+
spectrogram[f] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins]
|
721 |
+
return spectrogram.T
|
transformers_4_35_0/benchmark/__init__.py
ADDED
File without changes
|
transformers_4_35_0/benchmark/benchmark.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Benchmarking the library on inference and training in PyTorch.
|
18 |
+
"""
|
19 |
+
|
20 |
+
|
21 |
+
import timeit
|
22 |
+
from typing import Callable, Optional
|
23 |
+
|
24 |
+
from ..configuration_utils import PretrainedConfig
|
25 |
+
from ..models.auto.modeling_auto import MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING
|
26 |
+
from ..utils import is_py3nvml_available, is_torch_available, logging
|
27 |
+
from .benchmark_utils import (
|
28 |
+
Benchmark,
|
29 |
+
Memory,
|
30 |
+
MemorySummary,
|
31 |
+
measure_peak_memory_cpu,
|
32 |
+
start_memory_tracing,
|
33 |
+
stop_memory_tracing,
|
34 |
+
)
|
35 |
+
|
36 |
+
|
37 |
+
if is_torch_available():
|
38 |
+
import torch
|
39 |
+
|
40 |
+
from .benchmark_args import PyTorchBenchmarkArguments
|
41 |
+
|
42 |
+
|
43 |
+
if is_py3nvml_available():
|
44 |
+
import py3nvml.py3nvml as nvml
|
45 |
+
|
46 |
+
|
47 |
+
logger = logging.get_logger(__name__)
|
48 |
+
|
49 |
+
|
50 |
+
class PyTorchBenchmark(Benchmark):
|
51 |
+
args: PyTorchBenchmarkArguments
|
52 |
+
configs: PretrainedConfig
|
53 |
+
framework: str = "PyTorch"
|
54 |
+
|
55 |
+
@property
|
56 |
+
def framework_version(self):
|
57 |
+
return torch.__version__
|
58 |
+
|
59 |
+
def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
|
60 |
+
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
|
61 |
+
return self._measure_speed(_inference)
|
62 |
+
|
63 |
+
def _inference_memory(
|
64 |
+
self, model_name: str, batch_size: int, sequence_length: int
|
65 |
+
) -> [Memory, Optional[MemorySummary]]:
|
66 |
+
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
|
67 |
+
return self._measure_memory(_inference)
|
68 |
+
|
69 |
+
def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
|
70 |
+
_train = self._prepare_train_func(model_name, batch_size, sequence_length)
|
71 |
+
return self._measure_speed(_train)
|
72 |
+
|
73 |
+
def _train_memory(
|
74 |
+
self, model_name: str, batch_size: int, sequence_length: int
|
75 |
+
) -> [Memory, Optional[MemorySummary]]:
|
76 |
+
_train = self._prepare_train_func(model_name, batch_size, sequence_length)
|
77 |
+
return self._measure_memory(_train)
|
78 |
+
|
79 |
+
def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
|
80 |
+
config = self.config_dict[model_name]
|
81 |
+
|
82 |
+
if self.args.torchscript:
|
83 |
+
config.torchscript = True
|
84 |
+
|
85 |
+
has_model_class_in_config = (
|
86 |
+
hasattr(config, "architectures")
|
87 |
+
and isinstance(config.architectures, list)
|
88 |
+
and len(config.architectures) > 0
|
89 |
+
)
|
90 |
+
if not self.args.only_pretrain_model and has_model_class_in_config:
|
91 |
+
try:
|
92 |
+
model_class = config.architectures[0]
|
93 |
+
transformers_module = __import__("transformers", fromlist=[model_class])
|
94 |
+
model_cls = getattr(transformers_module, model_class)
|
95 |
+
model = model_cls(config)
|
96 |
+
except ImportError:
|
97 |
+
raise ImportError(
|
98 |
+
f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
|
99 |
+
" set `--only_pretrain_model` or `args.only_pretrain_model=True`."
|
100 |
+
)
|
101 |
+
else:
|
102 |
+
model = MODEL_MAPPING[config.__class__](config)
|
103 |
+
|
104 |
+
model.eval()
|
105 |
+
model.to(self.args.device)
|
106 |
+
|
107 |
+
# encoder-decoder has vocab size saved differently
|
108 |
+
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
|
109 |
+
input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device)
|
110 |
+
|
111 |
+
if self.args.fp16:
|
112 |
+
logger.info("Running training in Mixed Precision...")
|
113 |
+
if not self.args.is_gpu:
|
114 |
+
raise ValueError("Mixed precision is possible only for GPU.")
|
115 |
+
# amp seems to have memory leaks so that memory usage
|
116 |
+
# is measured using .half() for now https://github.com/NVIDIA/apex/issues/439
|
117 |
+
model.half()
|
118 |
+
|
119 |
+
if self.args.torchscript:
|
120 |
+
with torch.no_grad():
|
121 |
+
inference_model = torch.jit.trace(model, input_ids)
|
122 |
+
else:
|
123 |
+
inference_model = model
|
124 |
+
|
125 |
+
def encoder_decoder_forward():
|
126 |
+
with torch.no_grad():
|
127 |
+
outputs = inference_model(input_ids, decoder_input_ids=input_ids)
|
128 |
+
return outputs
|
129 |
+
|
130 |
+
def encoder_forward():
|
131 |
+
with torch.no_grad():
|
132 |
+
outputs = inference_model(input_ids)
|
133 |
+
return outputs
|
134 |
+
|
135 |
+
_forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
|
136 |
+
return _forward
|
137 |
+
|
138 |
+
def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
|
139 |
+
config = self.config_dict[model_name]
|
140 |
+
|
141 |
+
has_model_class_in_config = (
|
142 |
+
hasattr(config, "architectures")
|
143 |
+
and isinstance(config.architectures, list)
|
144 |
+
and len(config.architectures) > 0
|
145 |
+
)
|
146 |
+
if not self.args.only_pretrain_model and has_model_class_in_config:
|
147 |
+
try:
|
148 |
+
model_class = config.architectures[0]
|
149 |
+
transformers_module = __import__("transformers", fromlist=[model_class])
|
150 |
+
model_cls = getattr(transformers_module, model_class)
|
151 |
+
model = model_cls(config)
|
152 |
+
except ImportError:
|
153 |
+
raise ImportError(
|
154 |
+
f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
|
155 |
+
" set `--only_pretrain_model` or `args.only_pretrain_model=True`."
|
156 |
+
)
|
157 |
+
else:
|
158 |
+
model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
|
159 |
+
|
160 |
+
if self.args.torchscript:
|
161 |
+
raise NotImplementedError("Training for torchscript is currently not implemented")
|
162 |
+
else:
|
163 |
+
train_model = model
|
164 |
+
|
165 |
+
model.train()
|
166 |
+
model.to(self.args.device)
|
167 |
+
|
168 |
+
# encoder-decoder has vocab size saved differently
|
169 |
+
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
|
170 |
+
input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device)
|
171 |
+
|
172 |
+
if self.args.fp16:
|
173 |
+
logger.info("Running training in Mixed Precision...")
|
174 |
+
if not self.args.is_gpu:
|
175 |
+
raise ValueError("Mixed precision is possible only for GPU.")
|
176 |
+
|
177 |
+
# amp seems to have memory leaks so that memory usage
|
178 |
+
# is measured using .half() for now https://github.com/NVIDIA/apex/issues/439
|
179 |
+
model.half()
|
180 |
+
|
181 |
+
def compute_loss_and_backprob_encoder():
|
182 |
+
loss = train_model(input_ids, labels=input_ids)[0]
|
183 |
+
loss.backward()
|
184 |
+
return loss
|
185 |
+
|
186 |
+
def compute_loss_and_backprob_encoder_decoder():
|
187 |
+
loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0]
|
188 |
+
loss.backward()
|
189 |
+
return loss
|
190 |
+
|
191 |
+
_train = (
|
192 |
+
compute_loss_and_backprob_encoder_decoder
|
193 |
+
if config.is_encoder_decoder
|
194 |
+
else compute_loss_and_backprob_encoder
|
195 |
+
)
|
196 |
+
return _train
|
197 |
+
|
198 |
+
def _measure_speed(self, func) -> float:
|
199 |
+
try:
|
200 |
+
if self.args.is_tpu or self.args.torchscript:
|
201 |
+
# run additional 10 times to stabilize compilation for tpu and torchscript
|
202 |
+
logger.info("Do inference on TPU or torchscript. Running model 5 times to stabilize compilation")
|
203 |
+
timeit.repeat(
|
204 |
+
func,
|
205 |
+
repeat=1,
|
206 |
+
number=5,
|
207 |
+
)
|
208 |
+
|
209 |
+
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
|
210 |
+
runtimes = timeit.repeat(
|
211 |
+
func,
|
212 |
+
repeat=self.args.repeat,
|
213 |
+
number=10,
|
214 |
+
)
|
215 |
+
|
216 |
+
if self.args.is_tpu and self.args.torch_xla_tpu_print_metrics:
|
217 |
+
import torch_xla.debug.metrics as met
|
218 |
+
|
219 |
+
self.print_fn(met.metrics_report())
|
220 |
+
|
221 |
+
return min(runtimes) / 10.0
|
222 |
+
except RuntimeError as e:
|
223 |
+
self.print_fn(f"Doesn't fit on GPU. {e}")
|
224 |
+
return "N/A"
|
225 |
+
|
226 |
+
def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
|
227 |
+
try:
|
228 |
+
if self.args.trace_memory_line_by_line:
|
229 |
+
trace = start_memory_tracing("transformers")
|
230 |
+
|
231 |
+
if self.args.is_tpu:
|
232 |
+
# tpu
|
233 |
+
raise NotImplementedError(
|
234 |
+
"Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with"
|
235 |
+
" `--no-memory` or `args.memory=False`"
|
236 |
+
)
|
237 |
+
elif self.args.is_gpu:
|
238 |
+
if not is_py3nvml_available():
|
239 |
+
logger.warning(
|
240 |
+
"py3nvml not installed, we won't log GPU memory usage. "
|
241 |
+
"Install py3nvml (pip install py3nvml) to log information about GPU."
|
242 |
+
)
|
243 |
+
memory = "N/A"
|
244 |
+
else:
|
245 |
+
logger.info(
|
246 |
+
"Measuring total GPU usage on GPU device. Make sure to not have additional processes running"
|
247 |
+
" on the same GPU."
|
248 |
+
)
|
249 |
+
# init nvml
|
250 |
+
nvml.nvmlInit()
|
251 |
+
func()
|
252 |
+
handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
|
253 |
+
meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
|
254 |
+
max_bytes_in_use = meminfo.used
|
255 |
+
memory = Memory(max_bytes_in_use)
|
256 |
+
# shutdown nvml
|
257 |
+
nvml.nvmlShutdown()
|
258 |
+
else:
|
259 |
+
# cpu
|
260 |
+
memory_bytes = measure_peak_memory_cpu(func)
|
261 |
+
memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
|
262 |
+
|
263 |
+
if self.args.trace_memory_line_by_line:
|
264 |
+
summary = stop_memory_tracing(trace)
|
265 |
+
else:
|
266 |
+
summary = None
|
267 |
+
|
268 |
+
return memory, summary
|
269 |
+
except RuntimeError as e:
|
270 |
+
self.print_fn(f"Doesn't fit on GPU. {e}")
|
271 |
+
return "N/A", None
|
transformers_4_35_0/benchmark/benchmark_args.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
|
21 |
+
from .benchmark_args_utils import BenchmarkArguments
|
22 |
+
|
23 |
+
|
24 |
+
if is_torch_available():
|
25 |
+
import torch
|
26 |
+
|
27 |
+
if is_torch_tpu_available(check_device=False):
|
28 |
+
import torch_xla.core.xla_model as xm
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class PyTorchBenchmarkArguments(BenchmarkArguments):
|
36 |
+
deprecated_args = [
|
37 |
+
"no_inference",
|
38 |
+
"no_cuda",
|
39 |
+
"no_tpu",
|
40 |
+
"no_speed",
|
41 |
+
"no_memory",
|
42 |
+
"no_env_print",
|
43 |
+
"no_multi_process",
|
44 |
+
]
|
45 |
+
|
46 |
+
def __init__(self, **kwargs):
|
47 |
+
"""
|
48 |
+
This __init__ is there for legacy code. When removing deprecated args completely, the class can simply be
|
49 |
+
deleted
|
50 |
+
"""
|
51 |
+
for deprecated_arg in self.deprecated_args:
|
52 |
+
if deprecated_arg in kwargs:
|
53 |
+
positive_arg = deprecated_arg[3:]
|
54 |
+
setattr(self, positive_arg, not kwargs.pop(deprecated_arg))
|
55 |
+
logger.warning(
|
56 |
+
f"{deprecated_arg} is depreciated. Please use --no_{positive_arg} or"
|
57 |
+
f" {positive_arg}={kwargs[positive_arg]}"
|
58 |
+
)
|
59 |
+
|
60 |
+
self.torchscript = kwargs.pop("torchscript", self.torchscript)
|
61 |
+
self.torch_xla_tpu_print_metrics = kwargs.pop("torch_xla_tpu_print_metrics", self.torch_xla_tpu_print_metrics)
|
62 |
+
self.fp16_opt_level = kwargs.pop("fp16_opt_level", self.fp16_opt_level)
|
63 |
+
super().__init__(**kwargs)
|
64 |
+
|
65 |
+
torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"})
|
66 |
+
torch_xla_tpu_print_metrics: bool = field(default=False, metadata={"help": "Print Xla/PyTorch tpu metrics"})
|
67 |
+
fp16_opt_level: str = field(
|
68 |
+
default="O1",
|
69 |
+
metadata={
|
70 |
+
"help": (
|
71 |
+
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
|
72 |
+
"See details at https://nvidia.github.io/apex/amp.html"
|
73 |
+
)
|
74 |
+
},
|
75 |
+
)
|
76 |
+
|
77 |
+
@cached_property
|
78 |
+
def _setup_devices(self) -> Tuple["torch.device", int]:
|
79 |
+
requires_backends(self, ["torch"])
|
80 |
+
logger.info("PyTorch: setting up devices")
|
81 |
+
if not self.cuda:
|
82 |
+
device = torch.device("cpu")
|
83 |
+
n_gpu = 0
|
84 |
+
elif is_torch_tpu_available():
|
85 |
+
device = xm.xla_device()
|
86 |
+
n_gpu = 0
|
87 |
+
else:
|
88 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
89 |
+
n_gpu = torch.cuda.device_count()
|
90 |
+
return device, n_gpu
|
91 |
+
|
92 |
+
@property
|
93 |
+
def is_tpu(self):
|
94 |
+
return is_torch_tpu_available() and self.tpu
|
95 |
+
|
96 |
+
@property
|
97 |
+
def device_idx(self) -> int:
|
98 |
+
requires_backends(self, ["torch"])
|
99 |
+
# TODO(PVP): currently only single GPU is supported
|
100 |
+
return torch.cuda.current_device()
|
101 |
+
|
102 |
+
@property
|
103 |
+
def device(self) -> "torch.device":
|
104 |
+
requires_backends(self, ["torch"])
|
105 |
+
return self._setup_devices[0]
|
106 |
+
|
107 |
+
@property
|
108 |
+
def n_gpu(self):
|
109 |
+
requires_backends(self, ["torch"])
|
110 |
+
return self._setup_devices[1]
|
111 |
+
|
112 |
+
@property
|
113 |
+
def is_gpu(self):
|
114 |
+
return self.n_gpu > 0
|
transformers_4_35_0/benchmark/benchmark_args_tf.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
from ..utils import cached_property, is_tf_available, logging, requires_backends
|
21 |
+
from .benchmark_args_utils import BenchmarkArguments
|
22 |
+
|
23 |
+
|
24 |
+
if is_tf_available():
|
25 |
+
import tensorflow as tf
|
26 |
+
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
@dataclass
|
32 |
+
class TensorFlowBenchmarkArguments(BenchmarkArguments):
|
33 |
+
deprecated_args = [
|
34 |
+
"no_inference",
|
35 |
+
"no_cuda",
|
36 |
+
"no_tpu",
|
37 |
+
"no_speed",
|
38 |
+
"no_memory",
|
39 |
+
"no_env_print",
|
40 |
+
"no_multi_process",
|
41 |
+
]
|
42 |
+
|
43 |
+
def __init__(self, **kwargs):
|
44 |
+
"""
|
45 |
+
This __init__ is there for legacy code. When removing deprecated args completely, the class can simply be
|
46 |
+
deleted
|
47 |
+
"""
|
48 |
+
for deprecated_arg in self.deprecated_args:
|
49 |
+
if deprecated_arg in kwargs:
|
50 |
+
positive_arg = deprecated_arg[3:]
|
51 |
+
kwargs[positive_arg] = not kwargs.pop(deprecated_arg)
|
52 |
+
logger.warning(
|
53 |
+
f"{deprecated_arg} is depreciated. Please use --no-{positive_arg} or"
|
54 |
+
f" {positive_arg}={kwargs[positive_arg]}"
|
55 |
+
)
|
56 |
+
self.tpu_name = kwargs.pop("tpu_name", self.tpu_name)
|
57 |
+
self.device_idx = kwargs.pop("device_idx", self.device_idx)
|
58 |
+
self.eager_mode = kwargs.pop("eager_mode", self.eager_mode)
|
59 |
+
self.use_xla = kwargs.pop("use_xla", self.use_xla)
|
60 |
+
super().__init__(**kwargs)
|
61 |
+
|
62 |
+
tpu_name: str = field(
|
63 |
+
default=None,
|
64 |
+
metadata={"help": "Name of TPU"},
|
65 |
+
)
|
66 |
+
device_idx: int = field(
|
67 |
+
default=0,
|
68 |
+
metadata={"help": "CPU / GPU device index. Defaults to 0."},
|
69 |
+
)
|
70 |
+
eager_mode: bool = field(default=False, metadata={"help": "Benchmark models in eager model."})
|
71 |
+
use_xla: bool = field(
|
72 |
+
default=False,
|
73 |
+
metadata={
|
74 |
+
"help": "Benchmark models using XLA JIT compilation. Note that `eager_model` has to be set to `False`."
|
75 |
+
},
|
76 |
+
)
|
77 |
+
|
78 |
+
@cached_property
|
79 |
+
def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
|
80 |
+
requires_backends(self, ["tf"])
|
81 |
+
tpu = None
|
82 |
+
if self.tpu:
|
83 |
+
try:
|
84 |
+
if self.tpu_name:
|
85 |
+
tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name)
|
86 |
+
else:
|
87 |
+
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
|
88 |
+
except ValueError:
|
89 |
+
tpu = None
|
90 |
+
return tpu
|
91 |
+
|
92 |
+
@cached_property
|
93 |
+
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
|
94 |
+
requires_backends(self, ["tf"])
|
95 |
+
if self.is_tpu:
|
96 |
+
tf.config.experimental_connect_to_cluster(self._setup_tpu)
|
97 |
+
tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
|
98 |
+
|
99 |
+
strategy = tf.distribute.TPUStrategy(self._setup_tpu)
|
100 |
+
else:
|
101 |
+
# currently no multi gpu is allowed
|
102 |
+
if self.is_gpu:
|
103 |
+
# TODO: Currently only single GPU is supported
|
104 |
+
tf.config.set_visible_devices(self.gpu_list[self.device_idx], "GPU")
|
105 |
+
strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}")
|
106 |
+
else:
|
107 |
+
tf.config.set_visible_devices([], "GPU") # disable GPU
|
108 |
+
strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}")
|
109 |
+
|
110 |
+
return strategy
|
111 |
+
|
112 |
+
@property
|
113 |
+
def is_tpu(self) -> bool:
|
114 |
+
requires_backends(self, ["tf"])
|
115 |
+
return self._setup_tpu is not None
|
116 |
+
|
117 |
+
@property
|
118 |
+
def strategy(self) -> "tf.distribute.Strategy":
|
119 |
+
requires_backends(self, ["tf"])
|
120 |
+
return self._setup_strategy
|
121 |
+
|
122 |
+
@property
|
123 |
+
def gpu_list(self):
|
124 |
+
requires_backends(self, ["tf"])
|
125 |
+
return tf.config.list_physical_devices("GPU")
|
126 |
+
|
127 |
+
@property
|
128 |
+
def n_gpu(self) -> int:
|
129 |
+
requires_backends(self, ["tf"])
|
130 |
+
if self.cuda:
|
131 |
+
return len(self.gpu_list)
|
132 |
+
return 0
|
133 |
+
|
134 |
+
@property
|
135 |
+
def is_gpu(self) -> bool:
|
136 |
+
return self.n_gpu > 0
|
transformers_4_35_0/benchmark/benchmark_args_utils.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import dataclasses
|
18 |
+
import json
|
19 |
+
import warnings
|
20 |
+
from dataclasses import dataclass, field
|
21 |
+
from time import time
|
22 |
+
from typing import List
|
23 |
+
|
24 |
+
from ..utils import logging
|
25 |
+
|
26 |
+
|
27 |
+
logger = logging.get_logger(__name__)
|
28 |
+
|
29 |
+
|
30 |
+
def list_field(default=None, metadata=None):
|
31 |
+
return field(default_factory=lambda: default, metadata=metadata)
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class BenchmarkArguments:
|
36 |
+
"""
|
37 |
+
BenchMarkArguments are arguments we use in our benchmark scripts **which relate to the training loop itself**.
|
38 |
+
|
39 |
+
Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
|
40 |
+
line.
|
41 |
+
"""
|
42 |
+
|
43 |
+
models: List[str] = list_field(
|
44 |
+
default=[],
|
45 |
+
metadata={
|
46 |
+
"help": (
|
47 |
+
"Model checkpoints to be provided to the AutoModel classes. Leave blank to benchmark the base version"
|
48 |
+
" of all available models"
|
49 |
+
)
|
50 |
+
},
|
51 |
+
)
|
52 |
+
|
53 |
+
batch_sizes: List[int] = list_field(
|
54 |
+
default=[8], metadata={"help": "List of batch sizes for which memory and time performance will be evaluated"}
|
55 |
+
)
|
56 |
+
|
57 |
+
sequence_lengths: List[int] = list_field(
|
58 |
+
default=[8, 32, 128, 512],
|
59 |
+
metadata={"help": "List of sequence lengths for which memory and time performance will be evaluated"},
|
60 |
+
)
|
61 |
+
|
62 |
+
inference: bool = field(
|
63 |
+
default=True,
|
64 |
+
metadata={"help": "Whether to benchmark inference of model. Inference can be disabled via --no-inference."},
|
65 |
+
)
|
66 |
+
cuda: bool = field(
|
67 |
+
default=True,
|
68 |
+
metadata={"help": "Whether to run on available cuda devices. Cuda can be disabled via --no-cuda."},
|
69 |
+
)
|
70 |
+
tpu: bool = field(
|
71 |
+
default=True, metadata={"help": "Whether to run on available tpu devices. TPU can be disabled via --no-tpu."}
|
72 |
+
)
|
73 |
+
fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."})
|
74 |
+
training: bool = field(default=False, metadata={"help": "Benchmark training of model"})
|
75 |
+
verbose: bool = field(default=False, metadata={"help": "Verbose memory tracing"})
|
76 |
+
speed: bool = field(
|
77 |
+
default=True,
|
78 |
+
metadata={"help": "Whether to perform speed measurements. Speed measurements can be disabled via --no-speed."},
|
79 |
+
)
|
80 |
+
memory: bool = field(
|
81 |
+
default=True,
|
82 |
+
metadata={
|
83 |
+
"help": "Whether to perform memory measurements. Memory measurements can be disabled via --no-memory"
|
84 |
+
},
|
85 |
+
)
|
86 |
+
trace_memory_line_by_line: bool = field(default=False, metadata={"help": "Trace memory line by line"})
|
87 |
+
save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"})
|
88 |
+
log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"})
|
89 |
+
env_print: bool = field(default=False, metadata={"help": "Whether to print environment information"})
|
90 |
+
multi_process: bool = field(
|
91 |
+
default=True,
|
92 |
+
metadata={
|
93 |
+
"help": (
|
94 |
+
"Whether to use multiprocessing for memory and speed measurement. It is highly recommended to use"
|
95 |
+
" multiprocessing for accurate CPU and GPU memory measurements. This option should only be disabled"
|
96 |
+
" for debugging / testing and on TPU."
|
97 |
+
)
|
98 |
+
},
|
99 |
+
)
|
100 |
+
inference_time_csv_file: str = field(
|
101 |
+
default=f"inference_time_{round(time())}.csv",
|
102 |
+
metadata={"help": "CSV filename used if saving time results to csv."},
|
103 |
+
)
|
104 |
+
inference_memory_csv_file: str = field(
|
105 |
+
default=f"inference_memory_{round(time())}.csv",
|
106 |
+
metadata={"help": "CSV filename used if saving memory results to csv."},
|
107 |
+
)
|
108 |
+
train_time_csv_file: str = field(
|
109 |
+
default=f"train_time_{round(time())}.csv",
|
110 |
+
metadata={"help": "CSV filename used if saving time results to csv for training."},
|
111 |
+
)
|
112 |
+
train_memory_csv_file: str = field(
|
113 |
+
default=f"train_memory_{round(time())}.csv",
|
114 |
+
metadata={"help": "CSV filename used if saving memory results to csv for training."},
|
115 |
+
)
|
116 |
+
env_info_csv_file: str = field(
|
117 |
+
default=f"env_info_{round(time())}.csv",
|
118 |
+
metadata={"help": "CSV filename used if saving environment information."},
|
119 |
+
)
|
120 |
+
log_filename: str = field(
|
121 |
+
default=f"log_{round(time())}.csv",
|
122 |
+
metadata={"help": "Log filename used if print statements are saved in log."},
|
123 |
+
)
|
124 |
+
repeat: int = field(default=3, metadata={"help": "Times an experiment will be run."})
|
125 |
+
only_pretrain_model: bool = field(
|
126 |
+
default=False,
|
127 |
+
metadata={
|
128 |
+
"help": (
|
129 |
+
"Instead of loading the model as defined in `config.architectures` if exists, just load the pretrain"
|
130 |
+
" model weights."
|
131 |
+
)
|
132 |
+
},
|
133 |
+
)
|
134 |
+
|
135 |
+
def __post_init__(self):
|
136 |
+
warnings.warn(
|
137 |
+
f"The class {self.__class__} is deprecated. Hugging Face Benchmarking utils"
|
138 |
+
" are deprecated in general and it is advised to use external Benchmarking libraries "
|
139 |
+
" to benchmark Transformer models.",
|
140 |
+
FutureWarning,
|
141 |
+
)
|
142 |
+
|
143 |
+
def to_json_string(self):
|
144 |
+
"""
|
145 |
+
Serializes this instance to a JSON string.
|
146 |
+
"""
|
147 |
+
return json.dumps(dataclasses.asdict(self), indent=2)
|
148 |
+
|
149 |
+
@property
|
150 |
+
def model_names(self) -> List[str]:
|
151 |
+
if len(self.models) <= 0:
|
152 |
+
raise ValueError(
|
153 |
+
"Please make sure you provide at least one model name / model identifier, *e.g.* `--models"
|
154 |
+
" bert-base-cased` or `args.models = ['bert-base-cased']."
|
155 |
+
)
|
156 |
+
return self.models
|
157 |
+
|
158 |
+
@property
|
159 |
+
def do_multi_processing(self):
|
160 |
+
if not self.multi_process:
|
161 |
+
return False
|
162 |
+
elif self.is_tpu:
|
163 |
+
logger.info("Multiprocessing is currently not possible on TPU.")
|
164 |
+
return False
|
165 |
+
else:
|
166 |
+
return True
|
transformers_4_35_0/benchmark/benchmark_tf.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Benchmarking the library on inference and training in PyTorch.
|
18 |
+
"""
|
19 |
+
|
20 |
+
|
21 |
+
import random
|
22 |
+
import timeit
|
23 |
+
from functools import wraps
|
24 |
+
from typing import Callable, Optional
|
25 |
+
|
26 |
+
from ..configuration_utils import PretrainedConfig
|
27 |
+
from ..models.auto.modeling_tf_auto import TF_MODEL_MAPPING, TF_MODEL_WITH_LM_HEAD_MAPPING
|
28 |
+
from ..utils import is_py3nvml_available, is_tf_available, logging
|
29 |
+
from .benchmark_utils import (
|
30 |
+
Benchmark,
|
31 |
+
Memory,
|
32 |
+
MemorySummary,
|
33 |
+
measure_peak_memory_cpu,
|
34 |
+
start_memory_tracing,
|
35 |
+
stop_memory_tracing,
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
if is_tf_available():
|
40 |
+
import tensorflow as tf
|
41 |
+
from tensorflow.python.framework.errors_impl import ResourceExhaustedError
|
42 |
+
|
43 |
+
from .benchmark_args_tf import TensorFlowBenchmarkArguments
|
44 |
+
|
45 |
+
if is_py3nvml_available():
|
46 |
+
import py3nvml.py3nvml as nvml
|
47 |
+
|
48 |
+
logger = logging.get_logger(__name__)
|
49 |
+
|
50 |
+
|
51 |
+
def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
|
52 |
+
def run_func(func):
|
53 |
+
@wraps(func)
|
54 |
+
def run_in_eager_mode(*args, **kwargs):
|
55 |
+
return func(*args, **kwargs)
|
56 |
+
|
57 |
+
@wraps(func)
|
58 |
+
@tf.function(experimental_compile=use_xla)
|
59 |
+
def run_in_graph_mode(*args, **kwargs):
|
60 |
+
return func(*args, **kwargs)
|
61 |
+
|
62 |
+
if do_eager_mode is True:
|
63 |
+
if use_xla is not False:
|
64 |
+
raise ValueError(
|
65 |
+
"Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
|
66 |
+
)
|
67 |
+
return run_in_eager_mode
|
68 |
+
else:
|
69 |
+
return run_in_graph_mode
|
70 |
+
|
71 |
+
return run_func
|
72 |
+
|
73 |
+
|
74 |
+
def random_input_ids(batch_size: int, sequence_length: int, vocab_size: int) -> ["tf.Tensor"]:
|
75 |
+
rng = random.Random()
|
76 |
+
values = [rng.randint(0, vocab_size - 1) for i in range(batch_size * sequence_length)]
|
77 |
+
return tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32)
|
78 |
+
|
79 |
+
|
80 |
+
class TensorFlowBenchmark(Benchmark):
|
81 |
+
args: TensorFlowBenchmarkArguments
|
82 |
+
configs: PretrainedConfig
|
83 |
+
framework: str = "TensorFlow"
|
84 |
+
|
85 |
+
@property
|
86 |
+
def framework_version(self):
|
87 |
+
return tf.__version__
|
88 |
+
|
89 |
+
def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
|
90 |
+
# initialize GPU on separate process
|
91 |
+
strategy = self.args.strategy
|
92 |
+
if strategy is None:
|
93 |
+
raise ValueError("A device strategy has to be initialized before using TensorFlow.")
|
94 |
+
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
|
95 |
+
return self._measure_speed(_inference)
|
96 |
+
|
97 |
+
def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
|
98 |
+
strategy = self.args.strategy
|
99 |
+
if strategy is None:
|
100 |
+
raise ValueError("A device strategy has to be initialized before using TensorFlow.")
|
101 |
+
_train = self._prepare_train_func(model_name, batch_size, sequence_length)
|
102 |
+
return self._measure_speed(_train)
|
103 |
+
|
104 |
+
def _inference_memory(
|
105 |
+
self, model_name: str, batch_size: int, sequence_length: int
|
106 |
+
) -> [Memory, Optional[MemorySummary]]:
|
107 |
+
# initialize GPU on separate process
|
108 |
+
if self.args.is_gpu:
|
109 |
+
tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
|
110 |
+
strategy = self.args.strategy
|
111 |
+
if strategy is None:
|
112 |
+
raise ValueError("A device strategy has to be initialized before using TensorFlow.")
|
113 |
+
_inference = self._prepare_inference_func(model_name, batch_size, sequence_length)
|
114 |
+
return self._measure_memory(_inference)
|
115 |
+
|
116 |
+
def _train_memory(
|
117 |
+
self, model_name: str, batch_size: int, sequence_length: int
|
118 |
+
) -> [Memory, Optional[MemorySummary]]:
|
119 |
+
if self.args.is_gpu:
|
120 |
+
tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True)
|
121 |
+
strategy = self.args.strategy
|
122 |
+
if strategy is None:
|
123 |
+
raise ValueError("A device strategy has to be initialized before using TensorFlow.")
|
124 |
+
|
125 |
+
_train = self._prepare_train_func(model_name, batch_size, sequence_length)
|
126 |
+
return self._measure_memory(_train)
|
127 |
+
|
128 |
+
def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
|
129 |
+
config = self.config_dict[model_name]
|
130 |
+
|
131 |
+
if self.args.fp16:
|
132 |
+
raise NotImplementedError("Mixed precision is currently not supported.")
|
133 |
+
|
134 |
+
has_model_class_in_config = (
|
135 |
+
hasattr(config, "architectures")
|
136 |
+
and isinstance(config.architectures, list)
|
137 |
+
and len(config.architectures) > 0
|
138 |
+
)
|
139 |
+
if not self.args.only_pretrain_model and has_model_class_in_config:
|
140 |
+
try:
|
141 |
+
model_class = "TF" + config.architectures[0] # prepend 'TF' for tensorflow model
|
142 |
+
transformers_module = __import__("transformers", fromlist=[model_class])
|
143 |
+
model_cls = getattr(transformers_module, model_class)
|
144 |
+
model = model_cls(config)
|
145 |
+
except ImportError:
|
146 |
+
raise ImportError(
|
147 |
+
f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
|
148 |
+
" set `--only_pretrain_model` or `args.only_pretrain_model=True`."
|
149 |
+
)
|
150 |
+
else:
|
151 |
+
model = TF_MODEL_MAPPING[config.__class__](config)
|
152 |
+
|
153 |
+
# encoder-decoder has vocab size saved differently
|
154 |
+
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
|
155 |
+
input_ids = random_input_ids(batch_size, sequence_length, vocab_size)
|
156 |
+
|
157 |
+
@run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
|
158 |
+
def encoder_decoder_forward():
|
159 |
+
return model(input_ids, decoder_input_ids=input_ids, training=False)
|
160 |
+
|
161 |
+
@run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
|
162 |
+
def encoder_forward():
|
163 |
+
return model(input_ids, training=False)
|
164 |
+
|
165 |
+
_inference = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward
|
166 |
+
|
167 |
+
return _inference
|
168 |
+
|
169 |
+
def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]:
|
170 |
+
config = self.config_dict[model_name]
|
171 |
+
|
172 |
+
if self.args.eager_mode is not False:
|
173 |
+
raise ValueError("Training cannot be done in eager mode. Please make sure that `args.eager_mode = False`.")
|
174 |
+
|
175 |
+
if self.args.fp16:
|
176 |
+
raise NotImplementedError("Mixed precision is currently not supported.")
|
177 |
+
|
178 |
+
has_model_class_in_config = (
|
179 |
+
hasattr(config, "architectures")
|
180 |
+
and isinstance(config.architectures, list)
|
181 |
+
and len(config.architectures) > 0
|
182 |
+
)
|
183 |
+
if not self.args.only_pretrain_model and has_model_class_in_config:
|
184 |
+
try:
|
185 |
+
model_class = "TF" + config.architectures[0] # prepend 'TF' for tensorflow model
|
186 |
+
transformers_module = __import__("transformers", fromlist=[model_class])
|
187 |
+
model_cls = getattr(transformers_module, model_class)
|
188 |
+
model = model_cls(config)
|
189 |
+
except ImportError:
|
190 |
+
raise ImportError(
|
191 |
+
f"{model_class} does not exist. If you just want to test the pretrained model, you might want to"
|
192 |
+
" set `--only_pretrain_model` or `args.only_pretrain_model=True`."
|
193 |
+
)
|
194 |
+
else:
|
195 |
+
model = TF_MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config)
|
196 |
+
|
197 |
+
# encoder-decoder has vocab size saved differently
|
198 |
+
vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size
|
199 |
+
input_ids = random_input_ids(batch_size, sequence_length, vocab_size)
|
200 |
+
|
201 |
+
@run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
|
202 |
+
def encoder_decoder_train():
|
203 |
+
loss = model(input_ids, decoder_input_ids=input_ids, labels=input_ids, training=True)[0]
|
204 |
+
gradients = tf.gradients(loss, model.trainable_variables)
|
205 |
+
return gradients
|
206 |
+
|
207 |
+
@run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla)
|
208 |
+
def encoder_train():
|
209 |
+
loss = model(input_ids, labels=input_ids, training=True)[0]
|
210 |
+
gradients = tf.gradients(loss, model.trainable_variables)
|
211 |
+
return gradients
|
212 |
+
|
213 |
+
_train = encoder_decoder_train if config.is_encoder_decoder else encoder_train
|
214 |
+
|
215 |
+
return _train
|
216 |
+
|
217 |
+
def _measure_speed(self, func) -> float:
|
218 |
+
with self.args.strategy.scope():
|
219 |
+
try:
|
220 |
+
if self.args.is_tpu or self.args.use_xla:
|
221 |
+
# run additional 10 times to stabilize compilation for tpu
|
222 |
+
logger.info("Do inference on TPU. Running model 5 times to stabilize compilation")
|
223 |
+
timeit.repeat(func, repeat=1, number=5)
|
224 |
+
|
225 |
+
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
|
226 |
+
runtimes = timeit.repeat(
|
227 |
+
func,
|
228 |
+
repeat=self.args.repeat,
|
229 |
+
number=10,
|
230 |
+
)
|
231 |
+
|
232 |
+
return min(runtimes) / 10.0
|
233 |
+
except ResourceExhaustedError as e:
|
234 |
+
self.print_fn(f"Doesn't fit on GPU. {e}")
|
235 |
+
|
236 |
+
def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]:
|
237 |
+
logger.info(
|
238 |
+
"Note that TensorFlow allocates more memory than "
|
239 |
+
"it might need to speed up computation. "
|
240 |
+
"The memory reported here corresponds to the memory "
|
241 |
+
"reported by `nvidia-smi`, which can vary depending "
|
242 |
+
"on total available memory on the GPU that is used."
|
243 |
+
)
|
244 |
+
with self.args.strategy.scope():
|
245 |
+
try:
|
246 |
+
if self.args.trace_memory_line_by_line:
|
247 |
+
if not self.args.eager_mode:
|
248 |
+
raise ValueError(
|
249 |
+
"`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory"
|
250 |
+
" consumption line by line."
|
251 |
+
)
|
252 |
+
trace = start_memory_tracing("transformers")
|
253 |
+
|
254 |
+
if self.args.is_tpu:
|
255 |
+
# tpu
|
256 |
+
raise NotImplementedError(
|
257 |
+
"Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking"
|
258 |
+
" with `args.memory=False`"
|
259 |
+
)
|
260 |
+
elif self.args.is_gpu:
|
261 |
+
# gpu
|
262 |
+
if not is_py3nvml_available():
|
263 |
+
logger.warning(
|
264 |
+
"py3nvml not installed, we won't log GPU memory usage. "
|
265 |
+
"Install py3nvml (pip install py3nvml) to log information about GPU."
|
266 |
+
)
|
267 |
+
memory = "N/A"
|
268 |
+
else:
|
269 |
+
logger.info(
|
270 |
+
"Measuring total GPU usage on GPU device. Make sure to not have additional processes"
|
271 |
+
" running on the same GPU."
|
272 |
+
)
|
273 |
+
# init nvml
|
274 |
+
nvml.nvmlInit()
|
275 |
+
func()
|
276 |
+
handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
|
277 |
+
meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
|
278 |
+
max_bytes_in_use = meminfo.used
|
279 |
+
memory = Memory(max_bytes_in_use)
|
280 |
+
# shutdown nvml
|
281 |
+
nvml.nvmlShutdown()
|
282 |
+
else:
|
283 |
+
# cpu
|
284 |
+
if self.args.trace_memory_line_by_line:
|
285 |
+
logger.info(
|
286 |
+
"When enabling line by line tracing, the max peak memory for CPU is inaccurate in"
|
287 |
+
" TensorFlow."
|
288 |
+
)
|
289 |
+
memory = None
|
290 |
+
else:
|
291 |
+
memory_bytes = measure_peak_memory_cpu(func)
|
292 |
+
memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes
|
293 |
+
if self.args.trace_memory_line_by_line:
|
294 |
+
summary = stop_memory_tracing(trace)
|
295 |
+
if memory is None:
|
296 |
+
memory = summary.total
|
297 |
+
else:
|
298 |
+
summary = None
|
299 |
+
|
300 |
+
return memory, summary
|
301 |
+
except ResourceExhaustedError as e:
|
302 |
+
self.print_fn(f"Doesn't fit on GPU. {e}")
|
303 |
+
return "N/A", None
|
transformers_4_35_0/benchmark/benchmark_utils.py
ADDED
@@ -0,0 +1,914 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
2 |
+
|
3 |
+
# Copyright 2020 The HuggingFace Team and the AllenNLP authors. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Utilities for working with the local dataset cache.
|
18 |
+
"""
|
19 |
+
|
20 |
+
import copy
|
21 |
+
import csv
|
22 |
+
import linecache
|
23 |
+
import os
|
24 |
+
import platform
|
25 |
+
import sys
|
26 |
+
import warnings
|
27 |
+
from abc import ABC, abstractmethod
|
28 |
+
from collections import defaultdict, namedtuple
|
29 |
+
from datetime import datetime
|
30 |
+
from multiprocessing import Pipe, Process, Queue
|
31 |
+
from multiprocessing.connection import Connection
|
32 |
+
from typing import Callable, Iterable, List, NamedTuple, Optional, Union
|
33 |
+
|
34 |
+
from .. import AutoConfig, PretrainedConfig
|
35 |
+
from .. import __version__ as version
|
36 |
+
from ..utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available, logging
|
37 |
+
from .benchmark_args_utils import BenchmarkArguments
|
38 |
+
|
39 |
+
|
40 |
+
if is_torch_available():
|
41 |
+
from torch.cuda import empty_cache as torch_empty_cache
|
42 |
+
|
43 |
+
if is_tf_available():
|
44 |
+
from tensorflow.python.eager import context as tf_context
|
45 |
+
|
46 |
+
if is_psutil_available():
|
47 |
+
import psutil
|
48 |
+
|
49 |
+
if is_py3nvml_available():
|
50 |
+
import py3nvml.py3nvml as nvml
|
51 |
+
|
52 |
+
if platform.system() == "Windows":
|
53 |
+
from signal import CTRL_C_EVENT as SIGKILL
|
54 |
+
else:
|
55 |
+
from signal import SIGKILL
|
56 |
+
|
57 |
+
|
58 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
59 |
+
|
60 |
+
|
61 |
+
_is_memory_tracing_enabled = False
|
62 |
+
|
63 |
+
BenchmarkOutput = namedtuple(
|
64 |
+
"BenchmarkOutput",
|
65 |
+
[
|
66 |
+
"time_inference_result",
|
67 |
+
"memory_inference_result",
|
68 |
+
"time_train_result",
|
69 |
+
"memory_train_result",
|
70 |
+
"inference_summary",
|
71 |
+
"train_summary",
|
72 |
+
],
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
def separate_process_wrapper_fn(func: Callable[[], None], do_multi_processing: bool) -> Callable[[], None]:
|
77 |
+
"""
|
78 |
+
This function wraps another function into its own separated process. In order to ensure accurate memory
|
79 |
+
measurements it is important that the function is executed in a separate process
|
80 |
+
|
81 |
+
Args:
|
82 |
+
- `func`: (`callable`): function() -> ... generic function which will be executed in its own separate process
|
83 |
+
- `do_multi_processing`: (`bool`) Whether to run function on separate process or not
|
84 |
+
"""
|
85 |
+
|
86 |
+
def multi_process_func(*args, **kwargs):
|
87 |
+
# run function in an individual
|
88 |
+
# process to get correct memory
|
89 |
+
def wrapper_func(queue: Queue, *args):
|
90 |
+
try:
|
91 |
+
result = func(*args)
|
92 |
+
except Exception as e:
|
93 |
+
logger.error(e)
|
94 |
+
print(e)
|
95 |
+
result = "N/A"
|
96 |
+
queue.put(result)
|
97 |
+
|
98 |
+
queue = Queue()
|
99 |
+
p = Process(target=wrapper_func, args=[queue] + list(args))
|
100 |
+
p.start()
|
101 |
+
result = queue.get()
|
102 |
+
p.join()
|
103 |
+
return result
|
104 |
+
|
105 |
+
if do_multi_processing:
|
106 |
+
logger.info(f"Function {func} is executed in its own process...")
|
107 |
+
return multi_process_func
|
108 |
+
else:
|
109 |
+
return func
|
110 |
+
|
111 |
+
|
112 |
+
def is_memory_tracing_enabled():
|
113 |
+
global _is_memory_tracing_enabled
|
114 |
+
return _is_memory_tracing_enabled
|
115 |
+
|
116 |
+
|
117 |
+
class Frame(NamedTuple):
|
118 |
+
"""
|
119 |
+
`Frame` is a NamedTuple used to gather the current frame state. `Frame` has the following fields:
|
120 |
+
|
121 |
+
- 'filename' (string): Name of the file currently executed
|
122 |
+
- 'module' (string): Name of the module currently executed
|
123 |
+
- 'line_number' (int): Number of the line currently executed
|
124 |
+
- 'event' (string): Event that triggered the tracing (default will be "line")
|
125 |
+
- 'line_text' (string): Text of the line in the python script
|
126 |
+
"""
|
127 |
+
|
128 |
+
filename: str
|
129 |
+
module: str
|
130 |
+
line_number: int
|
131 |
+
event: str
|
132 |
+
line_text: str
|
133 |
+
|
134 |
+
|
135 |
+
class UsedMemoryState(NamedTuple):
|
136 |
+
"""
|
137 |
+
`UsedMemoryState` are named tuples with the following fields:
|
138 |
+
|
139 |
+
- 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current file,
|
140 |
+
location in current file)
|
141 |
+
- 'cpu_memory': CPU RSS memory state *before* executing the line
|
142 |
+
- 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only `gpus_to_trace` if
|
143 |
+
provided)
|
144 |
+
"""
|
145 |
+
|
146 |
+
frame: Frame
|
147 |
+
cpu_memory: int
|
148 |
+
gpu_memory: int
|
149 |
+
|
150 |
+
|
151 |
+
class Memory(NamedTuple):
|
152 |
+
"""
|
153 |
+
`Memory` NamedTuple have a single field `bytes` and you can get a human readable str of the number of mega bytes by
|
154 |
+
calling `__repr__`
|
155 |
+
|
156 |
+
- `byte` (integer): number of bytes,
|
157 |
+
"""
|
158 |
+
|
159 |
+
bytes: int
|
160 |
+
|
161 |
+
def __repr__(self) -> str:
|
162 |
+
return str(bytes_to_mega_bytes(self.bytes))
|
163 |
+
|
164 |
+
|
165 |
+
class MemoryState(NamedTuple):
|
166 |
+
"""
|
167 |
+
`MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields:
|
168 |
+
|
169 |
+
- `frame` (`Frame`): the current frame (see above)
|
170 |
+
- `cpu`: CPU memory consumed at during the current frame as a `Memory` named tuple
|
171 |
+
- `gpu`: GPU memory consumed at during the current frame as a `Memory` named tuple
|
172 |
+
- `cpu_gpu`: CPU + GPU memory consumed at during the current frame as a `Memory` named tuple
|
173 |
+
"""
|
174 |
+
|
175 |
+
frame: Frame
|
176 |
+
cpu: Memory
|
177 |
+
gpu: Memory
|
178 |
+
cpu_gpu: Memory
|
179 |
+
|
180 |
+
|
181 |
+
class MemorySummary(NamedTuple):
|
182 |
+
"""
|
183 |
+
`MemorySummary` namedtuple otherwise with the fields:
|
184 |
+
|
185 |
+
- `sequential`: a list of `MemoryState` namedtuple (see below) computed from the provided `memory_trace` by
|
186 |
+
subtracting the memory after executing each line from the memory before executing said line.
|
187 |
+
- `cumulative`: a list of `MemoryState` namedtuple (see below) with cumulative increase in memory for each line
|
188 |
+
obtained by summing repeated memory increase for a line if it's executed several times. The list is sorted
|
189 |
+
from the frame with the largest memory consumption to the frame with the smallest (can be negative if memory
|
190 |
+
is released)
|
191 |
+
- `total`: total memory increase during the full tracing as a `Memory` named tuple (see below). Line with
|
192 |
+
memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default).
|
193 |
+
"""
|
194 |
+
|
195 |
+
sequential: List[MemoryState]
|
196 |
+
cumulative: List[MemoryState]
|
197 |
+
current: List[MemoryState]
|
198 |
+
total: Memory
|
199 |
+
|
200 |
+
|
201 |
+
MemoryTrace = List[UsedMemoryState]
|
202 |
+
|
203 |
+
|
204 |
+
def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5, device_idx=None) -> int:
|
205 |
+
"""
|
206 |
+
measures peak cpu memory consumption of a given `function` running the function for at least interval seconds and
|
207 |
+
at most 20 * interval seconds. This function is heavily inspired by: `memory_usage` of the package
|
208 |
+
`memory_profiler`:
|
209 |
+
https://github.com/pythonprofilers/memory_profiler/blob/895c4ac7a08020d66ae001e24067da6dcea42451/memory_profiler.py#L239
|
210 |
+
|
211 |
+
Args:
|
212 |
+
- `function`: (`callable`): function() -> ... function without any arguments to measure for which to measure
|
213 |
+
the peak memory
|
214 |
+
|
215 |
+
- `interval`: (`float`, `optional`, defaults to `0.5`) interval in second for which to measure the memory usage
|
216 |
+
|
217 |
+
- `device_idx`: (`int`, `optional`, defaults to `None`) device id for which to measure gpu usage
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
|
221 |
+
- `max_memory`: (`int`) consumed memory peak in Bytes
|
222 |
+
"""
|
223 |
+
|
224 |
+
def get_cpu_memory(process_id: int) -> int:
|
225 |
+
"""
|
226 |
+
measures current cpu memory usage of a given `process_id`
|
227 |
+
|
228 |
+
Args:
|
229 |
+
- `process_id`: (`int`) process_id for which to measure memory
|
230 |
+
|
231 |
+
Returns
|
232 |
+
|
233 |
+
- `memory`: (`int`) consumed memory in Bytes
|
234 |
+
"""
|
235 |
+
process = psutil.Process(process_id)
|
236 |
+
try:
|
237 |
+
meminfo_attr = "memory_info" if hasattr(process, "memory_info") else "get_memory_info"
|
238 |
+
memory = getattr(process, meminfo_attr)()[0]
|
239 |
+
except psutil.AccessDenied:
|
240 |
+
raise ValueError("Error with Psutil.")
|
241 |
+
return memory
|
242 |
+
|
243 |
+
if not is_psutil_available():
|
244 |
+
logger.warning(
|
245 |
+
"Psutil not installed, we won't log CPU memory usage. "
|
246 |
+
"Install Psutil (pip install psutil) to use CPU memory tracing."
|
247 |
+
)
|
248 |
+
max_memory = "N/A"
|
249 |
+
else:
|
250 |
+
|
251 |
+
class MemoryMeasureProcess(Process):
|
252 |
+
|
253 |
+
"""
|
254 |
+
`MemoryMeasureProcess` inherits from `Process` and overwrites its `run()` method. Used to measure the
|
255 |
+
memory usage of a process
|
256 |
+
"""
|
257 |
+
|
258 |
+
def __init__(self, process_id: int, child_connection: Connection, interval: float):
|
259 |
+
super().__init__()
|
260 |
+
self.process_id = process_id
|
261 |
+
self.interval = interval
|
262 |
+
self.connection = child_connection
|
263 |
+
self.num_measurements = 1
|
264 |
+
self.mem_usage = get_cpu_memory(self.process_id)
|
265 |
+
|
266 |
+
def run(self):
|
267 |
+
self.connection.send(0)
|
268 |
+
stop = False
|
269 |
+
while True:
|
270 |
+
self.mem_usage = max(self.mem_usage, get_cpu_memory(self.process_id))
|
271 |
+
self.num_measurements += 1
|
272 |
+
|
273 |
+
if stop:
|
274 |
+
break
|
275 |
+
|
276 |
+
stop = self.connection.poll(self.interval)
|
277 |
+
|
278 |
+
# send results to parent pipe
|
279 |
+
self.connection.send(self.mem_usage)
|
280 |
+
self.connection.send(self.num_measurements)
|
281 |
+
|
282 |
+
while True:
|
283 |
+
# create child, parent connection
|
284 |
+
child_connection, parent_connection = Pipe()
|
285 |
+
|
286 |
+
# instantiate process
|
287 |
+
mem_process = MemoryMeasureProcess(os.getpid(), child_connection, interval)
|
288 |
+
mem_process.start()
|
289 |
+
|
290 |
+
# wait until we get memory
|
291 |
+
parent_connection.recv()
|
292 |
+
|
293 |
+
try:
|
294 |
+
# execute function
|
295 |
+
function()
|
296 |
+
|
297 |
+
# start parent connection
|
298 |
+
parent_connection.send(0)
|
299 |
+
|
300 |
+
# receive memory and num measurements
|
301 |
+
max_memory = parent_connection.recv()
|
302 |
+
num_measurements = parent_connection.recv()
|
303 |
+
except Exception:
|
304 |
+
# kill process in a clean way
|
305 |
+
parent = psutil.Process(os.getpid())
|
306 |
+
for child in parent.children(recursive=True):
|
307 |
+
os.kill(child.pid, SIGKILL)
|
308 |
+
mem_process.join(0)
|
309 |
+
raise RuntimeError("Process killed. Error in Process")
|
310 |
+
|
311 |
+
# run process at least 20 * interval or until it finishes
|
312 |
+
mem_process.join(20 * interval)
|
313 |
+
|
314 |
+
if (num_measurements > 4) or (interval < 1e-6):
|
315 |
+
break
|
316 |
+
|
317 |
+
# reduce interval
|
318 |
+
interval /= 10
|
319 |
+
|
320 |
+
return max_memory
|
321 |
+
|
322 |
+
|
323 |
+
def start_memory_tracing(
|
324 |
+
modules_to_trace: Optional[Union[str, Iterable[str]]] = None,
|
325 |
+
modules_not_to_trace: Optional[Union[str, Iterable[str]]] = None,
|
326 |
+
events_to_trace: str = "line",
|
327 |
+
gpus_to_trace: Optional[List[int]] = None,
|
328 |
+
) -> MemoryTrace:
|
329 |
+
"""
|
330 |
+
Setup line-by-line tracing to record rss mem (RAM) at each line of a module or sub-module. See `./benchmark.py` for
|
331 |
+
usage examples. Current memory consumption is returned using psutil and in particular is the RSS memory "Resident
|
332 |
+
Set Size” (the non-swapped physical memory the process is using). See
|
333 |
+
https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info
|
334 |
+
|
335 |
+
Args:
|
336 |
+
- `modules_to_trace`: (None, string, list/tuple of string) if None, all events are recorded if string or list
|
337 |
+
of strings: only events from the listed module/sub-module will be recorded (e.g. 'fairseq' or
|
338 |
+
'transformers.models.gpt2.modeling_gpt2')
|
339 |
+
- `modules_not_to_trace`: (None, string, list/tuple of string) if None, no module is avoided if string or list
|
340 |
+
of strings: events from the listed module/sub-module will not be recorded (e.g. 'torch')
|
341 |
+
- `events_to_trace`: string or list of string of events to be recorded (see official python doc for
|
342 |
+
`sys.settrace` for the list of events) default to line
|
343 |
+
- `gpus_to_trace`: (optional list, default None) list of GPUs to trace. Default to tracing all GPUs
|
344 |
+
|
345 |
+
Return:
|
346 |
+
|
347 |
+
- `memory_trace` is a list of `UsedMemoryState` for each event (default each line of the traced script).
|
348 |
+
|
349 |
+
- `UsedMemoryState` are named tuples with the following fields:
|
350 |
+
|
351 |
+
- 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current
|
352 |
+
file, location in current file)
|
353 |
+
- 'cpu_memory': CPU RSS memory state *before* executing the line
|
354 |
+
- 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only
|
355 |
+
`gpus_to_trace` if provided)
|
356 |
+
|
357 |
+
`Frame` is a namedtuple used by `UsedMemoryState` to list the current frame state. `Frame` has the following
|
358 |
+
fields: - 'filename' (string): Name of the file currently executed - 'module' (string): Name of the module
|
359 |
+
currently executed - 'line_number' (int): Number of the line currently executed - 'event' (string): Event that
|
360 |
+
triggered the tracing (default will be "line") - 'line_text' (string): Text of the line in the python script
|
361 |
+
|
362 |
+
"""
|
363 |
+
if is_psutil_available():
|
364 |
+
process = psutil.Process(os.getpid())
|
365 |
+
else:
|
366 |
+
logger.warning(
|
367 |
+
"Psutil not installed, we won't log CPU memory usage. "
|
368 |
+
"Install psutil (pip install psutil) to use CPU memory tracing."
|
369 |
+
)
|
370 |
+
process = None
|
371 |
+
|
372 |
+
if is_py3nvml_available():
|
373 |
+
try:
|
374 |
+
nvml.nvmlInit()
|
375 |
+
devices = list(range(nvml.nvmlDeviceGetCount())) if gpus_to_trace is None else gpus_to_trace
|
376 |
+
nvml.nvmlShutdown()
|
377 |
+
except (OSError, nvml.NVMLError):
|
378 |
+
logger.warning("Error while initializing communication with GPU. We won't perform GPU memory tracing.")
|
379 |
+
log_gpu = False
|
380 |
+
else:
|
381 |
+
log_gpu = is_torch_available() or is_tf_available()
|
382 |
+
else:
|
383 |
+
logger.warning(
|
384 |
+
"py3nvml not installed, we won't log GPU memory usage. "
|
385 |
+
"Install py3nvml (pip install py3nvml) to use GPU memory tracing."
|
386 |
+
)
|
387 |
+
log_gpu = False
|
388 |
+
|
389 |
+
memory_trace = []
|
390 |
+
|
391 |
+
def traceit(frame, event, args):
|
392 |
+
"""
|
393 |
+
Tracing method executed before running each line in a module or sub-module Record memory allocated in a list
|
394 |
+
with debugging information
|
395 |
+
"""
|
396 |
+
global _is_memory_tracing_enabled
|
397 |
+
|
398 |
+
if not _is_memory_tracing_enabled:
|
399 |
+
return traceit
|
400 |
+
|
401 |
+
# Filter events
|
402 |
+
if events_to_trace is not None:
|
403 |
+
if isinstance(events_to_trace, str) and event != events_to_trace:
|
404 |
+
return traceit
|
405 |
+
elif isinstance(events_to_trace, (list, tuple)) and event not in events_to_trace:
|
406 |
+
return traceit
|
407 |
+
|
408 |
+
if "__name__" not in frame.f_globals:
|
409 |
+
return traceit
|
410 |
+
|
411 |
+
# Filter modules
|
412 |
+
name = frame.f_globals["__name__"]
|
413 |
+
if not isinstance(name, str):
|
414 |
+
return traceit
|
415 |
+
else:
|
416 |
+
# Filter whitelist of modules to trace
|
417 |
+
if modules_to_trace is not None:
|
418 |
+
if isinstance(modules_to_trace, str) and modules_to_trace not in name:
|
419 |
+
return traceit
|
420 |
+
elif isinstance(modules_to_trace, (list, tuple)) and all(m not in name for m in modules_to_trace):
|
421 |
+
return traceit
|
422 |
+
|
423 |
+
# Filter blacklist of modules not to trace
|
424 |
+
if modules_not_to_trace is not None:
|
425 |
+
if isinstance(modules_not_to_trace, str) and modules_not_to_trace in name:
|
426 |
+
return traceit
|
427 |
+
elif isinstance(modules_not_to_trace, (list, tuple)) and any(m in name for m in modules_not_to_trace):
|
428 |
+
return traceit
|
429 |
+
|
430 |
+
# Record current tracing state (file, location in file...)
|
431 |
+
lineno = frame.f_lineno
|
432 |
+
filename = frame.f_globals["__file__"]
|
433 |
+
if filename.endswith(".pyc") or filename.endswith(".pyo"):
|
434 |
+
filename = filename[:-1]
|
435 |
+
line = linecache.getline(filename, lineno).rstrip()
|
436 |
+
traced_state = Frame(filename, name, lineno, event, line)
|
437 |
+
|
438 |
+
# Record current memory state (rss memory) and compute difference with previous memory state
|
439 |
+
cpu_mem = 0
|
440 |
+
if process is not None:
|
441 |
+
mem = process.memory_info()
|
442 |
+
cpu_mem = mem.rss
|
443 |
+
|
444 |
+
gpu_mem = 0
|
445 |
+
if log_gpu:
|
446 |
+
# Clear GPU caches
|
447 |
+
if is_torch_available():
|
448 |
+
torch_empty_cache()
|
449 |
+
if is_tf_available():
|
450 |
+
tf_context.context()._clear_caches() # See https://github.com/tensorflow/tensorflow/issues/20218#issuecomment-416771802
|
451 |
+
|
452 |
+
# Sum used memory for all GPUs
|
453 |
+
nvml.nvmlInit()
|
454 |
+
|
455 |
+
for i in devices:
|
456 |
+
handle = nvml.nvmlDeviceGetHandleByIndex(i)
|
457 |
+
meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
|
458 |
+
gpu_mem += meminfo.used
|
459 |
+
|
460 |
+
nvml.nvmlShutdown()
|
461 |
+
|
462 |
+
mem_state = UsedMemoryState(traced_state, cpu_mem, gpu_mem)
|
463 |
+
memory_trace.append(mem_state)
|
464 |
+
|
465 |
+
return traceit
|
466 |
+
|
467 |
+
sys.settrace(traceit)
|
468 |
+
|
469 |
+
global _is_memory_tracing_enabled
|
470 |
+
_is_memory_tracing_enabled = True
|
471 |
+
|
472 |
+
return memory_trace
|
473 |
+
|
474 |
+
|
475 |
+
def stop_memory_tracing(
|
476 |
+
memory_trace: Optional[MemoryTrace] = None, ignore_released_memory: bool = True
|
477 |
+
) -> Optional[MemorySummary]:
|
478 |
+
"""
|
479 |
+
Stop memory tracing cleanly and return a summary of the memory trace if a trace is given.
|
480 |
+
|
481 |
+
Args:
|
482 |
+
`memory_trace` (optional output of start_memory_tracing, default: None):
|
483 |
+
memory trace to convert in summary
|
484 |
+
`ignore_released_memory` (boolean, default: None):
|
485 |
+
if True we only sum memory increase to compute total memory
|
486 |
+
|
487 |
+
Return:
|
488 |
+
|
489 |
+
- None if `memory_trace` is None
|
490 |
+
- `MemorySummary` namedtuple otherwise with the fields:
|
491 |
+
|
492 |
+
- `sequential`: a list of `MemoryState` namedtuple (see below) computed from the provided `memory_trace` by
|
493 |
+
subtracting the memory after executing each line from the memory before executing said line.
|
494 |
+
- `cumulative`: a list of `MemoryState` namedtuple (see below) with cumulative increase in memory for each
|
495 |
+
line obtained by summing repeated memory increase for a line if it's executed several times. The list is
|
496 |
+
sorted from the frame with the largest memory consumption to the frame with the smallest (can be negative
|
497 |
+
if memory is released)
|
498 |
+
- `total`: total memory increase during the full tracing as a `Memory` named tuple (see below). Line with
|
499 |
+
memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default).
|
500 |
+
|
501 |
+
`Memory` named tuple have fields
|
502 |
+
|
503 |
+
- `byte` (integer): number of bytes,
|
504 |
+
- `string` (string): same as human readable string (ex: "3.5MB")
|
505 |
+
|
506 |
+
`Frame` are namedtuple used to list the current frame state and have the following fields:
|
507 |
+
|
508 |
+
- 'filename' (string): Name of the file currently executed
|
509 |
+
- 'module' (string): Name of the module currently executed
|
510 |
+
- 'line_number' (int): Number of the line currently executed
|
511 |
+
- 'event' (string): Event that triggered the tracing (default will be "line")
|
512 |
+
- 'line_text' (string): Text of the line in the python script
|
513 |
+
|
514 |
+
`MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields:
|
515 |
+
|
516 |
+
- `frame` (`Frame`): the current frame (see above)
|
517 |
+
- `cpu`: CPU memory consumed at during the current frame as a `Memory` named tuple
|
518 |
+
- `gpu`: GPU memory consumed at during the current frame as a `Memory` named tuple
|
519 |
+
- `cpu_gpu`: CPU + GPU memory consumed at during the current frame as a `Memory` named tuple
|
520 |
+
"""
|
521 |
+
global _is_memory_tracing_enabled
|
522 |
+
_is_memory_tracing_enabled = False
|
523 |
+
|
524 |
+
if memory_trace is not None and len(memory_trace) > 1:
|
525 |
+
memory_diff_trace = []
|
526 |
+
memory_curr_trace = []
|
527 |
+
|
528 |
+
cumulative_memory_dict = defaultdict(lambda: [0, 0, 0])
|
529 |
+
|
530 |
+
for (
|
531 |
+
(frame, cpu_mem, gpu_mem),
|
532 |
+
(next_frame, next_cpu_mem, next_gpu_mem),
|
533 |
+
) in zip(memory_trace[:-1], memory_trace[1:]):
|
534 |
+
cpu_mem_inc = next_cpu_mem - cpu_mem
|
535 |
+
gpu_mem_inc = next_gpu_mem - gpu_mem
|
536 |
+
cpu_gpu_mem_inc = cpu_mem_inc + gpu_mem_inc
|
537 |
+
memory_diff_trace.append(
|
538 |
+
MemoryState(
|
539 |
+
frame=frame,
|
540 |
+
cpu=Memory(cpu_mem_inc),
|
541 |
+
gpu=Memory(gpu_mem_inc),
|
542 |
+
cpu_gpu=Memory(cpu_gpu_mem_inc),
|
543 |
+
)
|
544 |
+
)
|
545 |
+
|
546 |
+
memory_curr_trace.append(
|
547 |
+
MemoryState(
|
548 |
+
frame=frame,
|
549 |
+
cpu=Memory(next_cpu_mem),
|
550 |
+
gpu=Memory(next_gpu_mem),
|
551 |
+
cpu_gpu=Memory(next_gpu_mem + next_cpu_mem),
|
552 |
+
)
|
553 |
+
)
|
554 |
+
|
555 |
+
cumulative_memory_dict[frame][0] += cpu_mem_inc
|
556 |
+
cumulative_memory_dict[frame][1] += gpu_mem_inc
|
557 |
+
cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc
|
558 |
+
|
559 |
+
cumulative_memory = sorted(
|
560 |
+
cumulative_memory_dict.items(), key=lambda x: x[1][2], reverse=True
|
561 |
+
) # order by the total CPU + GPU memory increase
|
562 |
+
cumulative_memory = [
|
563 |
+
MemoryState(
|
564 |
+
frame=frame,
|
565 |
+
cpu=Memory(cpu_mem_inc),
|
566 |
+
gpu=Memory(gpu_mem_inc),
|
567 |
+
cpu_gpu=Memory(cpu_gpu_mem_inc),
|
568 |
+
)
|
569 |
+
for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
|
570 |
+
]
|
571 |
+
|
572 |
+
memory_curr_trace = sorted(memory_curr_trace, key=lambda x: x.cpu_gpu.bytes, reverse=True)
|
573 |
+
|
574 |
+
if ignore_released_memory:
|
575 |
+
total_memory = sum(max(0, step_trace.cpu_gpu.bytes) for step_trace in memory_diff_trace)
|
576 |
+
else:
|
577 |
+
total_memory = sum(step_trace.cpu_gpu.bytes for step_trace in memory_diff_trace)
|
578 |
+
|
579 |
+
total_memory = Memory(total_memory)
|
580 |
+
|
581 |
+
return MemorySummary(
|
582 |
+
sequential=memory_diff_trace,
|
583 |
+
cumulative=cumulative_memory,
|
584 |
+
current=memory_curr_trace,
|
585 |
+
total=total_memory,
|
586 |
+
)
|
587 |
+
|
588 |
+
return None
|
589 |
+
|
590 |
+
|
591 |
+
def bytes_to_mega_bytes(memory_amount: int) -> int:
|
592 |
+
"""Utility to convert a number of bytes (int) into a number of mega bytes (int)"""
|
593 |
+
return memory_amount >> 20
|
594 |
+
|
595 |
+
|
596 |
+
class Benchmark(ABC):
|
597 |
+
"""
|
598 |
+
Benchmarks is a simple but feature-complete benchmarking script to compare memory and time performance of models in
|
599 |
+
Transformers.
|
600 |
+
"""
|
601 |
+
|
602 |
+
args: BenchmarkArguments
|
603 |
+
configs: PretrainedConfig
|
604 |
+
framework: str
|
605 |
+
|
606 |
+
def __init__(self, args: BenchmarkArguments = None, configs: PretrainedConfig = None):
|
607 |
+
self.args = args
|
608 |
+
if configs is None:
|
609 |
+
self.config_dict = {
|
610 |
+
model_name: AutoConfig.from_pretrained(model_name) for model_name in self.args.model_names
|
611 |
+
}
|
612 |
+
else:
|
613 |
+
self.config_dict = dict(zip(self.args.model_names, configs))
|
614 |
+
|
615 |
+
warnings.warn(
|
616 |
+
f"The class {self.__class__} is deprecated. Hugging Face Benchmarking utils"
|
617 |
+
" are deprecated in general and it is advised to use external Benchmarking libraries "
|
618 |
+
" to benchmark Transformer models.",
|
619 |
+
FutureWarning,
|
620 |
+
)
|
621 |
+
|
622 |
+
if self.args.memory and os.getenv("TRANSFORMERS_USE_MULTIPROCESSING") == 0:
|
623 |
+
logger.warning(
|
624 |
+
"Memory consumption will not be measured accurately if `args.multi_process` is set to `False.` The"
|
625 |
+
" flag 'TRANSFORMERS_USE_MULTIPROCESSING' should only be disabled for debugging / testing."
|
626 |
+
)
|
627 |
+
|
628 |
+
self._print_fn = None
|
629 |
+
self._framework_version = None
|
630 |
+
self._environment_info = None
|
631 |
+
|
632 |
+
@property
|
633 |
+
def print_fn(self):
|
634 |
+
if self._print_fn is None:
|
635 |
+
if self.args.log_print:
|
636 |
+
|
637 |
+
def print_and_log(*args):
|
638 |
+
with open(self.args.log_filename, "a") as log_file:
|
639 |
+
log_file.write("".join(args) + "\n")
|
640 |
+
print(*args)
|
641 |
+
|
642 |
+
self._print_fn = print_and_log
|
643 |
+
else:
|
644 |
+
self._print_fn = print
|
645 |
+
return self._print_fn
|
646 |
+
|
647 |
+
@property
|
648 |
+
@abstractmethod
|
649 |
+
def framework_version(self):
|
650 |
+
pass
|
651 |
+
|
652 |
+
@abstractmethod
|
653 |
+
def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
|
654 |
+
pass
|
655 |
+
|
656 |
+
@abstractmethod
|
657 |
+
def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
|
658 |
+
pass
|
659 |
+
|
660 |
+
@abstractmethod
|
661 |
+
def _inference_memory(
|
662 |
+
self, model_name: str, batch_size: int, sequence_length: int
|
663 |
+
) -> [Memory, Optional[MemorySummary]]:
|
664 |
+
pass
|
665 |
+
|
666 |
+
@abstractmethod
|
667 |
+
def _train_memory(
|
668 |
+
self, model_name: str, batch_size: int, sequence_length: int
|
669 |
+
) -> [Memory, Optional[MemorySummary]]:
|
670 |
+
pass
|
671 |
+
|
672 |
+
def inference_speed(self, *args, **kwargs) -> float:
|
673 |
+
return separate_process_wrapper_fn(self._inference_speed, self.args.do_multi_processing)(*args, **kwargs)
|
674 |
+
|
675 |
+
def train_speed(self, *args, **kwargs) -> float:
|
676 |
+
return separate_process_wrapper_fn(self._train_speed, self.args.do_multi_processing)(*args, **kwargs)
|
677 |
+
|
678 |
+
def inference_memory(self, *args, **kwargs) -> [Memory, Optional[MemorySummary]]:
|
679 |
+
return separate_process_wrapper_fn(self._inference_memory, self.args.do_multi_processing)(*args, **kwargs)
|
680 |
+
|
681 |
+
def train_memory(self, *args, **kwargs) -> [Memory, Optional[MemorySummary]]:
|
682 |
+
return separate_process_wrapper_fn(self._train_memory, self.args.do_multi_processing)(*args, **kwargs)
|
683 |
+
|
684 |
+
def run(self):
|
685 |
+
result_dict = {model_name: {} for model_name in self.args.model_names}
|
686 |
+
inference_result_time = copy.deepcopy(result_dict)
|
687 |
+
inference_result_memory = copy.deepcopy(result_dict)
|
688 |
+
train_result_time = copy.deepcopy(result_dict)
|
689 |
+
train_result_memory = copy.deepcopy(result_dict)
|
690 |
+
|
691 |
+
for c, model_name in enumerate(self.args.model_names):
|
692 |
+
self.print_fn(f"{c + 1} / {len(self.args.model_names)}")
|
693 |
+
|
694 |
+
model_dict = {
|
695 |
+
"bs": self.args.batch_sizes,
|
696 |
+
"ss": self.args.sequence_lengths,
|
697 |
+
"result": {i: {} for i in self.args.batch_sizes},
|
698 |
+
}
|
699 |
+
inference_result_time[model_name] = copy.deepcopy(model_dict)
|
700 |
+
inference_result_memory[model_name] = copy.deepcopy(model_dict)
|
701 |
+
train_result_time[model_name] = copy.deepcopy(model_dict)
|
702 |
+
train_result_memory[model_name] = copy.deepcopy(model_dict)
|
703 |
+
|
704 |
+
inference_summary = train_summary = None
|
705 |
+
|
706 |
+
for batch_size in self.args.batch_sizes:
|
707 |
+
for sequence_length in self.args.sequence_lengths:
|
708 |
+
if self.args.inference:
|
709 |
+
if self.args.memory:
|
710 |
+
memory, inference_summary = self.inference_memory(model_name, batch_size, sequence_length)
|
711 |
+
inference_result_memory[model_name]["result"][batch_size][sequence_length] = memory
|
712 |
+
if self.args.speed:
|
713 |
+
time = self.inference_speed(model_name, batch_size, sequence_length)
|
714 |
+
inference_result_time[model_name]["result"][batch_size][sequence_length] = time
|
715 |
+
|
716 |
+
if self.args.training:
|
717 |
+
if self.args.memory:
|
718 |
+
memory, train_summary = self.train_memory(model_name, batch_size, sequence_length)
|
719 |
+
train_result_memory[model_name]["result"][batch_size][sequence_length] = memory
|
720 |
+
if self.args.speed:
|
721 |
+
time = self.train_speed(model_name, batch_size, sequence_length)
|
722 |
+
train_result_time[model_name]["result"][batch_size][sequence_length] = time
|
723 |
+
|
724 |
+
if self.args.inference:
|
725 |
+
if self.args.speed:
|
726 |
+
self.print_fn("\n" + 20 * "=" + ("INFERENCE - SPEED - RESULT").center(40) + 20 * "=")
|
727 |
+
self.print_results(inference_result_time, type_label="Time in s")
|
728 |
+
self.save_to_csv(inference_result_time, self.args.inference_time_csv_file)
|
729 |
+
if self.args.is_tpu:
|
730 |
+
self.print_fn(
|
731 |
+
"TPU was used for inference. Note that the time after compilation stabilized (after ~10"
|
732 |
+
" inferences model.forward(..) calls) was measured."
|
733 |
+
)
|
734 |
+
|
735 |
+
if self.args.memory:
|
736 |
+
self.print_fn("\n" + 20 * "=" + ("INFERENCE - MEMORY - RESULT").center(40) + 20 * "=")
|
737 |
+
self.print_results(inference_result_memory, type_label="Memory in MB")
|
738 |
+
self.save_to_csv(inference_result_memory, self.args.inference_memory_csv_file)
|
739 |
+
|
740 |
+
if self.args.trace_memory_line_by_line:
|
741 |
+
self.print_fn("\n" + 20 * "=" + ("INFERENCE - MEMOMRY - LINE BY LINE - SUMMARY").center(40) + 20 * "=")
|
742 |
+
self.print_memory_trace_statistics(inference_summary)
|
743 |
+
|
744 |
+
if self.args.training:
|
745 |
+
if self.args.speed:
|
746 |
+
self.print_fn("\n" + 20 * "=" + ("TRAIN - SPEED - RESULTS").center(40) + 20 * "=")
|
747 |
+
self.print_results(train_result_time, "Time in s")
|
748 |
+
self.save_to_csv(train_result_time, self.args.train_time_csv_file)
|
749 |
+
if self.args.is_tpu:
|
750 |
+
self.print_fn(
|
751 |
+
"TPU was used for training. Note that the time after compilation stabilized (after ~10 train"
|
752 |
+
" loss=model.forward(...) + loss.backward() calls) was measured."
|
753 |
+
)
|
754 |
+
|
755 |
+
if self.args.memory:
|
756 |
+
self.print_fn("\n" + 20 * "=" + ("TRAIN - MEMORY - RESULTS").center(40) + 20 * "=")
|
757 |
+
self.print_results(train_result_memory, type_label="Memory in MB")
|
758 |
+
self.save_to_csv(train_result_memory, self.args.train_memory_csv_file)
|
759 |
+
|
760 |
+
if self.args.trace_memory_line_by_line:
|
761 |
+
self.print_fn("\n" + 20 * "=" + ("TRAIN - MEMOMRY - LINE BY LINE - SUMMARY").center(40) + 20 * "=")
|
762 |
+
self.print_memory_trace_statistics(train_summary)
|
763 |
+
|
764 |
+
if self.args.env_print:
|
765 |
+
self.print_fn("\n" + 20 * "=" + ("ENVIRONMENT INFORMATION").center(40) + 20 * "=")
|
766 |
+
self.print_fn("\n".join([f"- {prop}: {val}" for prop, val in self.environment_info.items()]) + "\n")
|
767 |
+
|
768 |
+
if self.args.save_to_csv:
|
769 |
+
with open(self.args.env_info_csv_file, mode="w", newline="") as csv_file:
|
770 |
+
writer = csv.writer(csv_file)
|
771 |
+
for key, value in self.environment_info.items():
|
772 |
+
writer.writerow([key, value])
|
773 |
+
|
774 |
+
return BenchmarkOutput(
|
775 |
+
inference_result_time,
|
776 |
+
inference_result_memory,
|
777 |
+
train_result_time,
|
778 |
+
train_result_memory,
|
779 |
+
inference_summary,
|
780 |
+
train_summary,
|
781 |
+
)
|
782 |
+
|
783 |
+
@property
|
784 |
+
def environment_info(self):
|
785 |
+
if self._environment_info is None:
|
786 |
+
info = {}
|
787 |
+
info["transformers_version"] = version
|
788 |
+
info["framework"] = self.framework
|
789 |
+
if self.framework == "PyTorch":
|
790 |
+
info["use_torchscript"] = self.args.torchscript
|
791 |
+
if self.framework == "TensorFlow":
|
792 |
+
info["eager_mode"] = self.args.eager_mode
|
793 |
+
info["use_xla"] = self.args.use_xla
|
794 |
+
info["framework_version"] = self.framework_version
|
795 |
+
info["python_version"] = platform.python_version()
|
796 |
+
info["system"] = platform.system()
|
797 |
+
info["cpu"] = platform.processor()
|
798 |
+
info["architecture"] = platform.architecture()[0]
|
799 |
+
info["date"] = datetime.date(datetime.now())
|
800 |
+
info["time"] = datetime.time(datetime.now())
|
801 |
+
info["fp16"] = self.args.fp16
|
802 |
+
info["use_multiprocessing"] = self.args.do_multi_processing
|
803 |
+
info["only_pretrain_model"] = self.args.only_pretrain_model
|
804 |
+
|
805 |
+
if is_psutil_available():
|
806 |
+
info["cpu_ram_mb"] = bytes_to_mega_bytes(psutil.virtual_memory().total)
|
807 |
+
else:
|
808 |
+
logger.warning(
|
809 |
+
"Psutil not installed, we won't log available CPU memory. "
|
810 |
+
"Install psutil (pip install psutil) to log available CPU memory."
|
811 |
+
)
|
812 |
+
info["cpu_ram_mb"] = "N/A"
|
813 |
+
|
814 |
+
info["use_gpu"] = self.args.is_gpu
|
815 |
+
if self.args.is_gpu:
|
816 |
+
info["num_gpus"] = 1 # TODO(PVP) Currently only single GPU is supported
|
817 |
+
if is_py3nvml_available():
|
818 |
+
nvml.nvmlInit()
|
819 |
+
handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
|
820 |
+
info["gpu"] = nvml.nvmlDeviceGetName(handle)
|
821 |
+
info["gpu_ram_mb"] = bytes_to_mega_bytes(nvml.nvmlDeviceGetMemoryInfo(handle).total)
|
822 |
+
info["gpu_power_watts"] = nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000
|
823 |
+
info["gpu_performance_state"] = nvml.nvmlDeviceGetPerformanceState(handle)
|
824 |
+
nvml.nvmlShutdown()
|
825 |
+
else:
|
826 |
+
logger.warning(
|
827 |
+
"py3nvml not installed, we won't log GPU memory usage. "
|
828 |
+
"Install py3nvml (pip install py3nvml) to log information about GPU."
|
829 |
+
)
|
830 |
+
info["gpu"] = "N/A"
|
831 |
+
info["gpu_ram_mb"] = "N/A"
|
832 |
+
info["gpu_power_watts"] = "N/A"
|
833 |
+
info["gpu_performance_state"] = "N/A"
|
834 |
+
|
835 |
+
info["use_tpu"] = self.args.is_tpu
|
836 |
+
# TODO(PVP): See if we can add more information about TPU
|
837 |
+
# see: https://github.com/pytorch/xla/issues/2180
|
838 |
+
|
839 |
+
self._environment_info = info
|
840 |
+
return self._environment_info
|
841 |
+
|
842 |
+
def print_results(self, result_dict, type_label):
|
843 |
+
self.print_fn(80 * "-")
|
844 |
+
self.print_fn(
|
845 |
+
"Model Name".center(30) + "Batch Size".center(15) + "Seq Length".center(15) + type_label.center(15)
|
846 |
+
)
|
847 |
+
self.print_fn(80 * "-")
|
848 |
+
for model_name in self.args.model_names:
|
849 |
+
for batch_size in result_dict[model_name]["bs"]:
|
850 |
+
for sequence_length in result_dict[model_name]["ss"]:
|
851 |
+
result = result_dict[model_name]["result"][batch_size][sequence_length]
|
852 |
+
if isinstance(result, float):
|
853 |
+
result = round(1000 * result) / 1000
|
854 |
+
result = "< 0.001" if result == 0.0 else str(result)
|
855 |
+
else:
|
856 |
+
result = str(result)
|
857 |
+
self.print_fn(
|
858 |
+
model_name[:30].center(30) + str(batch_size).center(15),
|
859 |
+
str(sequence_length).center(15),
|
860 |
+
result.center(15),
|
861 |
+
)
|
862 |
+
self.print_fn(80 * "-")
|
863 |
+
|
864 |
+
def print_memory_trace_statistics(self, summary: MemorySummary):
|
865 |
+
self.print_fn(
|
866 |
+
"\nLine by line memory consumption:\n"
|
867 |
+
+ "\n".join(
|
868 |
+
f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
|
869 |
+
for state in summary.sequential
|
870 |
+
)
|
871 |
+
)
|
872 |
+
self.print_fn(
|
873 |
+
"\nLines with top memory consumption:\n"
|
874 |
+
+ "\n".join(
|
875 |
+
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
|
876 |
+
for state in summary.cumulative[:6]
|
877 |
+
)
|
878 |
+
)
|
879 |
+
self.print_fn(
|
880 |
+
"\nLines with lowest memory consumption:\n"
|
881 |
+
+ "\n".join(
|
882 |
+
f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
|
883 |
+
for state in summary.cumulative[-6:]
|
884 |
+
)
|
885 |
+
)
|
886 |
+
self.print_fn(f"\nTotal memory increase: {summary.total}")
|
887 |
+
|
888 |
+
def save_to_csv(self, result_dict, filename):
|
889 |
+
if not self.args.save_to_csv:
|
890 |
+
return
|
891 |
+
self.print_fn("Saving results to csv.")
|
892 |
+
with open(filename, mode="w") as csv_file:
|
893 |
+
if len(self.args.model_names) <= 0:
|
894 |
+
raise ValueError(f"At least 1 model should be defined, but got {self.model_names}")
|
895 |
+
|
896 |
+
fieldnames = ["model", "batch_size", "sequence_length"]
|
897 |
+
writer = csv.DictWriter(csv_file, fieldnames=fieldnames + ["result"])
|
898 |
+
writer.writeheader()
|
899 |
+
|
900 |
+
for model_name in self.args.model_names:
|
901 |
+
result_dict_model = result_dict[model_name]["result"]
|
902 |
+
for bs in result_dict_model:
|
903 |
+
for ss in result_dict_model[bs]:
|
904 |
+
result_model = result_dict_model[bs][ss]
|
905 |
+
writer.writerow(
|
906 |
+
{
|
907 |
+
"model": model_name,
|
908 |
+
"batch_size": bs,
|
909 |
+
"sequence_length": ss,
|
910 |
+
"result": ("{}" if not isinstance(result_model, float) else "{:.4f}").format(
|
911 |
+
result_model
|
912 |
+
),
|
913 |
+
}
|
914 |
+
)
|
transformers_4_35_0/commands/__init__.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from abc import ABC, abstractmethod
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
|
18 |
+
|
19 |
+
class BaseTransformersCLICommand(ABC):
|
20 |
+
@staticmethod
|
21 |
+
@abstractmethod
|
22 |
+
def register_subcommand(parser: ArgumentParser):
|
23 |
+
raise NotImplementedError()
|
24 |
+
|
25 |
+
@abstractmethod
|
26 |
+
def run(self):
|
27 |
+
raise NotImplementedError()
|
transformers_4_35_0/commands/add_new_model.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
import shutil
|
18 |
+
import warnings
|
19 |
+
from argparse import ArgumentParser, Namespace
|
20 |
+
from pathlib import Path
|
21 |
+
from typing import List
|
22 |
+
|
23 |
+
from ..utils import logging
|
24 |
+
from . import BaseTransformersCLICommand
|
25 |
+
|
26 |
+
|
27 |
+
try:
|
28 |
+
from cookiecutter.main import cookiecutter
|
29 |
+
|
30 |
+
_has_cookiecutter = True
|
31 |
+
except ImportError:
|
32 |
+
_has_cookiecutter = False
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
35 |
+
|
36 |
+
|
37 |
+
def add_new_model_command_factory(args: Namespace):
|
38 |
+
return AddNewModelCommand(args.testing, args.testing_file, path=args.path)
|
39 |
+
|
40 |
+
|
41 |
+
class AddNewModelCommand(BaseTransformersCLICommand):
|
42 |
+
@staticmethod
|
43 |
+
def register_subcommand(parser: ArgumentParser):
|
44 |
+
add_new_model_parser = parser.add_parser("add-new-model")
|
45 |
+
add_new_model_parser.add_argument("--testing", action="store_true", help="If in testing mode.")
|
46 |
+
add_new_model_parser.add_argument("--testing_file", type=str, help="Configuration file on which to run.")
|
47 |
+
add_new_model_parser.add_argument(
|
48 |
+
"--path", type=str, help="Path to cookiecutter. Should only be used for testing purposes."
|
49 |
+
)
|
50 |
+
add_new_model_parser.set_defaults(func=add_new_model_command_factory)
|
51 |
+
|
52 |
+
def __init__(self, testing: bool, testing_file: str, path=None, *args):
|
53 |
+
self._testing = testing
|
54 |
+
self._testing_file = testing_file
|
55 |
+
self._path = path
|
56 |
+
|
57 |
+
def run(self):
|
58 |
+
warnings.warn(
|
59 |
+
"The command `transformers-cli add-new-model` is deprecated and will be removed in v5 of Transformers. "
|
60 |
+
"It is not actively maintained anymore, so might give a result that won't pass all tests and quality "
|
61 |
+
"checks, you should use `transformers-cli add-new-model-like` instead."
|
62 |
+
)
|
63 |
+
if not _has_cookiecutter:
|
64 |
+
raise ImportError(
|
65 |
+
"Model creation dependencies are required to use the `add_new_model` command. Install them by running "
|
66 |
+
"the following at the root of your `transformers` clone:\n\n\t$ pip install -e .[modelcreation]\n"
|
67 |
+
)
|
68 |
+
# Ensure that there is no other `cookiecutter-template-xxx` directory in the current working directory
|
69 |
+
directories = [directory for directory in os.listdir() if "cookiecutter-template-" == directory[:22]]
|
70 |
+
if len(directories) > 0:
|
71 |
+
raise ValueError(
|
72 |
+
"Several directories starting with `cookiecutter-template-` in current working directory. "
|
73 |
+
"Please clean your directory by removing all folders starting with `cookiecutter-template-` or "
|
74 |
+
"change your working directory."
|
75 |
+
)
|
76 |
+
|
77 |
+
path_to_transformer_root = (
|
78 |
+
Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent
|
79 |
+
)
|
80 |
+
path_to_cookiecutter = path_to_transformer_root / "templates" / "adding_a_new_model"
|
81 |
+
|
82 |
+
# Execute cookiecutter
|
83 |
+
if not self._testing:
|
84 |
+
cookiecutter(str(path_to_cookiecutter))
|
85 |
+
else:
|
86 |
+
with open(self._testing_file, "r") as configuration_file:
|
87 |
+
testing_configuration = json.load(configuration_file)
|
88 |
+
|
89 |
+
cookiecutter(
|
90 |
+
str(path_to_cookiecutter if self._path is None else self._path),
|
91 |
+
no_input=True,
|
92 |
+
extra_context=testing_configuration,
|
93 |
+
)
|
94 |
+
|
95 |
+
directory = [directory for directory in os.listdir() if "cookiecutter-template-" in directory[:22]][0]
|
96 |
+
|
97 |
+
# Retrieve configuration
|
98 |
+
with open(directory + "/configuration.json", "r") as configuration_file:
|
99 |
+
configuration = json.load(configuration_file)
|
100 |
+
|
101 |
+
lowercase_model_name = configuration["lowercase_modelname"]
|
102 |
+
generate_tensorflow_pytorch_and_flax = configuration["generate_tensorflow_pytorch_and_flax"]
|
103 |
+
os.remove(f"{directory}/configuration.json")
|
104 |
+
|
105 |
+
output_pytorch = "PyTorch" in generate_tensorflow_pytorch_and_flax
|
106 |
+
output_tensorflow = "TensorFlow" in generate_tensorflow_pytorch_and_flax
|
107 |
+
output_flax = "Flax" in generate_tensorflow_pytorch_and_flax
|
108 |
+
|
109 |
+
model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}"
|
110 |
+
os.makedirs(model_dir, exist_ok=True)
|
111 |
+
os.makedirs(f"{path_to_transformer_root}/tests/models/{lowercase_model_name}", exist_ok=True)
|
112 |
+
|
113 |
+
# Tests require submodules as they have parent imports
|
114 |
+
with open(f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/__init__.py", "w"):
|
115 |
+
pass
|
116 |
+
|
117 |
+
shutil.move(
|
118 |
+
f"{directory}/__init__.py",
|
119 |
+
f"{model_dir}/__init__.py",
|
120 |
+
)
|
121 |
+
shutil.move(
|
122 |
+
f"{directory}/configuration_{lowercase_model_name}.py",
|
123 |
+
f"{model_dir}/configuration_{lowercase_model_name}.py",
|
124 |
+
)
|
125 |
+
|
126 |
+
def remove_copy_lines(path):
|
127 |
+
with open(path, "r") as f:
|
128 |
+
lines = f.readlines()
|
129 |
+
with open(path, "w") as f:
|
130 |
+
for line in lines:
|
131 |
+
if "# Copied from transformers." not in line:
|
132 |
+
f.write(line)
|
133 |
+
|
134 |
+
if output_pytorch:
|
135 |
+
if not self._testing:
|
136 |
+
remove_copy_lines(f"{directory}/modeling_{lowercase_model_name}.py")
|
137 |
+
|
138 |
+
shutil.move(
|
139 |
+
f"{directory}/modeling_{lowercase_model_name}.py",
|
140 |
+
f"{model_dir}/modeling_{lowercase_model_name}.py",
|
141 |
+
)
|
142 |
+
|
143 |
+
shutil.move(
|
144 |
+
f"{directory}/test_modeling_{lowercase_model_name}.py",
|
145 |
+
f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_{lowercase_model_name}.py",
|
146 |
+
)
|
147 |
+
else:
|
148 |
+
os.remove(f"{directory}/modeling_{lowercase_model_name}.py")
|
149 |
+
os.remove(f"{directory}/test_modeling_{lowercase_model_name}.py")
|
150 |
+
|
151 |
+
if output_tensorflow:
|
152 |
+
if not self._testing:
|
153 |
+
remove_copy_lines(f"{directory}/modeling_tf_{lowercase_model_name}.py")
|
154 |
+
|
155 |
+
shutil.move(
|
156 |
+
f"{directory}/modeling_tf_{lowercase_model_name}.py",
|
157 |
+
f"{model_dir}/modeling_tf_{lowercase_model_name}.py",
|
158 |
+
)
|
159 |
+
|
160 |
+
shutil.move(
|
161 |
+
f"{directory}/test_modeling_tf_{lowercase_model_name}.py",
|
162 |
+
f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_tf_{lowercase_model_name}.py",
|
163 |
+
)
|
164 |
+
else:
|
165 |
+
os.remove(f"{directory}/modeling_tf_{lowercase_model_name}.py")
|
166 |
+
os.remove(f"{directory}/test_modeling_tf_{lowercase_model_name}.py")
|
167 |
+
|
168 |
+
if output_flax:
|
169 |
+
if not self._testing:
|
170 |
+
remove_copy_lines(f"{directory}/modeling_flax_{lowercase_model_name}.py")
|
171 |
+
|
172 |
+
shutil.move(
|
173 |
+
f"{directory}/modeling_flax_{lowercase_model_name}.py",
|
174 |
+
f"{model_dir}/modeling_flax_{lowercase_model_name}.py",
|
175 |
+
)
|
176 |
+
|
177 |
+
shutil.move(
|
178 |
+
f"{directory}/test_modeling_flax_{lowercase_model_name}.py",
|
179 |
+
f"{path_to_transformer_root}/tests/models/{lowercase_model_name}/test_modeling_flax_{lowercase_model_name}.py",
|
180 |
+
)
|
181 |
+
else:
|
182 |
+
os.remove(f"{directory}/modeling_flax_{lowercase_model_name}.py")
|
183 |
+
os.remove(f"{directory}/test_modeling_flax_{lowercase_model_name}.py")
|
184 |
+
|
185 |
+
shutil.move(
|
186 |
+
f"{directory}/{lowercase_model_name}.md",
|
187 |
+
f"{path_to_transformer_root}/docs/source/en/model_doc/{lowercase_model_name}.md",
|
188 |
+
)
|
189 |
+
|
190 |
+
shutil.move(
|
191 |
+
f"{directory}/tokenization_{lowercase_model_name}.py",
|
192 |
+
f"{model_dir}/tokenization_{lowercase_model_name}.py",
|
193 |
+
)
|
194 |
+
|
195 |
+
shutil.move(
|
196 |
+
f"{directory}/tokenization_fast_{lowercase_model_name}.py",
|
197 |
+
f"{model_dir}/tokenization_{lowercase_model_name}_fast.py",
|
198 |
+
)
|
199 |
+
|
200 |
+
from os import fdopen, remove
|
201 |
+
from shutil import copymode, move
|
202 |
+
from tempfile import mkstemp
|
203 |
+
|
204 |
+
def replace(original_file: str, line_to_copy_below: str, lines_to_copy: List[str]):
|
205 |
+
# Create temp file
|
206 |
+
fh, abs_path = mkstemp()
|
207 |
+
line_found = False
|
208 |
+
with fdopen(fh, "w") as new_file:
|
209 |
+
with open(original_file) as old_file:
|
210 |
+
for line in old_file:
|
211 |
+
new_file.write(line)
|
212 |
+
if line_to_copy_below in line:
|
213 |
+
line_found = True
|
214 |
+
for line_to_copy in lines_to_copy:
|
215 |
+
new_file.write(line_to_copy)
|
216 |
+
|
217 |
+
if not line_found:
|
218 |
+
raise ValueError(f"Line {line_to_copy_below} was not found in file.")
|
219 |
+
|
220 |
+
# Copy the file permissions from the old file to the new file
|
221 |
+
copymode(original_file, abs_path)
|
222 |
+
# Remove original file
|
223 |
+
remove(original_file)
|
224 |
+
# Move new file
|
225 |
+
move(abs_path, original_file)
|
226 |
+
|
227 |
+
def skip_units(line):
|
228 |
+
return (
|
229 |
+
("generating PyTorch" in line and not output_pytorch)
|
230 |
+
or ("generating TensorFlow" in line and not output_tensorflow)
|
231 |
+
or ("generating Flax" in line and not output_flax)
|
232 |
+
)
|
233 |
+
|
234 |
+
def replace_in_files(path_to_datafile):
|
235 |
+
with open(path_to_datafile) as datafile:
|
236 |
+
lines_to_copy = []
|
237 |
+
skip_file = False
|
238 |
+
skip_snippet = False
|
239 |
+
for line in datafile:
|
240 |
+
if "# To replace in: " in line and "##" not in line:
|
241 |
+
file_to_replace_in = line.split('"')[1]
|
242 |
+
skip_file = skip_units(line)
|
243 |
+
elif "# Below: " in line and "##" not in line:
|
244 |
+
line_to_copy_below = line.split('"')[1]
|
245 |
+
skip_snippet = skip_units(line)
|
246 |
+
elif "# End." in line and "##" not in line:
|
247 |
+
if not skip_file and not skip_snippet:
|
248 |
+
replace(file_to_replace_in, line_to_copy_below, lines_to_copy)
|
249 |
+
|
250 |
+
lines_to_copy = []
|
251 |
+
elif "# Replace with" in line and "##" not in line:
|
252 |
+
lines_to_copy = []
|
253 |
+
elif "##" not in line:
|
254 |
+
lines_to_copy.append(line)
|
255 |
+
|
256 |
+
remove(path_to_datafile)
|
257 |
+
|
258 |
+
replace_in_files(f"{directory}/to_replace_{lowercase_model_name}.py")
|
259 |
+
os.rmdir(directory)
|
transformers_4_35_0/commands/add_new_model_like.py
ADDED
@@ -0,0 +1,1763 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import difflib
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
import re
|
19 |
+
from argparse import ArgumentParser, Namespace
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from datetime import date
|
22 |
+
from itertools import chain
|
23 |
+
from pathlib import Path
|
24 |
+
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
|
25 |
+
|
26 |
+
import yaml
|
27 |
+
|
28 |
+
from ..models import auto as auto_module
|
29 |
+
from ..models.auto.configuration_auto import model_type_to_module_name
|
30 |
+
from ..utils import is_flax_available, is_tf_available, is_torch_available, logging
|
31 |
+
from . import BaseTransformersCLICommand
|
32 |
+
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
35 |
+
|
36 |
+
|
37 |
+
CURRENT_YEAR = date.today().year
|
38 |
+
TRANSFORMERS_PATH = Path(__file__).parent.parent
|
39 |
+
REPO_PATH = TRANSFORMERS_PATH.parent.parent
|
40 |
+
|
41 |
+
|
42 |
+
@dataclass
|
43 |
+
class ModelPatterns:
|
44 |
+
"""
|
45 |
+
Holds the basic information about a new model for the add-new-model-like command.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
model_name (`str`): The model name.
|
49 |
+
checkpoint (`str`): The checkpoint to use for doc examples.
|
50 |
+
model_type (`str`, *optional*):
|
51 |
+
The model type, the identifier used internally in the library like `bert` or `xlm-roberta`. Will default to
|
52 |
+
`model_name` lowercased with spaces replaced with minuses (-).
|
53 |
+
model_lower_cased (`str`, *optional*):
|
54 |
+
The lowercased version of the model name, to use for the module name or function names. Will default to
|
55 |
+
`model_name` lowercased with spaces and minuses replaced with underscores.
|
56 |
+
model_camel_cased (`str`, *optional*):
|
57 |
+
The camel-cased version of the model name, to use for the class names. Will default to `model_name`
|
58 |
+
camel-cased (with spaces and minuses both considered as word separators.
|
59 |
+
model_upper_cased (`str`, *optional*):
|
60 |
+
The uppercased version of the model name, to use for the constant names. Will default to `model_name`
|
61 |
+
uppercased with spaces and minuses replaced with underscores.
|
62 |
+
config_class (`str`, *optional*):
|
63 |
+
The tokenizer class associated with this model. Will default to `"{model_camel_cased}Config"`.
|
64 |
+
tokenizer_class (`str`, *optional*):
|
65 |
+
The tokenizer class associated with this model (leave to `None` for models that don't use a tokenizer).
|
66 |
+
image_processor_class (`str`, *optional*):
|
67 |
+
The image processor class associated with this model (leave to `None` for models that don't use an image
|
68 |
+
processor).
|
69 |
+
feature_extractor_class (`str`, *optional*):
|
70 |
+
The feature extractor class associated with this model (leave to `None` for models that don't use a feature
|
71 |
+
extractor).
|
72 |
+
processor_class (`str`, *optional*):
|
73 |
+
The processor class associated with this model (leave to `None` for models that don't use a processor).
|
74 |
+
"""
|
75 |
+
|
76 |
+
model_name: str
|
77 |
+
checkpoint: str
|
78 |
+
model_type: Optional[str] = None
|
79 |
+
model_lower_cased: Optional[str] = None
|
80 |
+
model_camel_cased: Optional[str] = None
|
81 |
+
model_upper_cased: Optional[str] = None
|
82 |
+
config_class: Optional[str] = None
|
83 |
+
tokenizer_class: Optional[str] = None
|
84 |
+
image_processor_class: Optional[str] = None
|
85 |
+
feature_extractor_class: Optional[str] = None
|
86 |
+
processor_class: Optional[str] = None
|
87 |
+
|
88 |
+
def __post_init__(self):
|
89 |
+
if self.model_type is None:
|
90 |
+
self.model_type = self.model_name.lower().replace(" ", "-")
|
91 |
+
if self.model_lower_cased is None:
|
92 |
+
self.model_lower_cased = self.model_name.lower().replace(" ", "_").replace("-", "_")
|
93 |
+
if self.model_camel_cased is None:
|
94 |
+
# Split the model name on - and space
|
95 |
+
words = self.model_name.split(" ")
|
96 |
+
words = list(chain(*[w.split("-") for w in words]))
|
97 |
+
# Make sure each word is capitalized
|
98 |
+
words = [w[0].upper() + w[1:] for w in words]
|
99 |
+
self.model_camel_cased = "".join(words)
|
100 |
+
if self.model_upper_cased is None:
|
101 |
+
self.model_upper_cased = self.model_name.upper().replace(" ", "_").replace("-", "_")
|
102 |
+
if self.config_class is None:
|
103 |
+
self.config_class = f"{self.model_camel_cased}Config"
|
104 |
+
|
105 |
+
|
106 |
+
ATTRIBUTE_TO_PLACEHOLDER = {
|
107 |
+
"config_class": "[CONFIG_CLASS]",
|
108 |
+
"tokenizer_class": "[TOKENIZER_CLASS]",
|
109 |
+
"image_processor_class": "[IMAGE_PROCESSOR_CLASS]",
|
110 |
+
"feature_extractor_class": "[FEATURE_EXTRACTOR_CLASS]",
|
111 |
+
"processor_class": "[PROCESSOR_CLASS]",
|
112 |
+
"checkpoint": "[CHECKPOINT]",
|
113 |
+
"model_type": "[MODEL_TYPE]",
|
114 |
+
"model_upper_cased": "[MODEL_UPPER_CASED]",
|
115 |
+
"model_camel_cased": "[MODEL_CAMELCASED]",
|
116 |
+
"model_lower_cased": "[MODEL_LOWER_CASED]",
|
117 |
+
"model_name": "[MODEL_NAME]",
|
118 |
+
}
|
119 |
+
|
120 |
+
|
121 |
+
def is_empty_line(line: str) -> bool:
|
122 |
+
"""
|
123 |
+
Determines whether a line is empty or not.
|
124 |
+
"""
|
125 |
+
return len(line) == 0 or line.isspace()
|
126 |
+
|
127 |
+
|
128 |
+
def find_indent(line: str) -> int:
|
129 |
+
"""
|
130 |
+
Returns the number of spaces that start a line indent.
|
131 |
+
"""
|
132 |
+
search = re.search(r"^(\s*)(?:\S|$)", line)
|
133 |
+
if search is None:
|
134 |
+
return 0
|
135 |
+
return len(search.groups()[0])
|
136 |
+
|
137 |
+
|
138 |
+
def parse_module_content(content: str) -> List[str]:
|
139 |
+
"""
|
140 |
+
Parse the content of a module in the list of objects it defines.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
content (`str`): The content to parse
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
`List[str]`: The list of objects defined in the module.
|
147 |
+
"""
|
148 |
+
objects = []
|
149 |
+
current_object = []
|
150 |
+
lines = content.split("\n")
|
151 |
+
# Doc-styler takes everything between two triple quotes in docstrings, so we need a fake """ here to go with this.
|
152 |
+
end_markers = [")", "]", "}", '"""']
|
153 |
+
|
154 |
+
for line in lines:
|
155 |
+
# End of an object
|
156 |
+
is_valid_object = len(current_object) > 0
|
157 |
+
if is_valid_object and len(current_object) == 1:
|
158 |
+
is_valid_object = not current_object[0].startswith("# Copied from")
|
159 |
+
if not is_empty_line(line) and find_indent(line) == 0 and is_valid_object:
|
160 |
+
# Closing parts should be included in current object
|
161 |
+
if line in end_markers:
|
162 |
+
current_object.append(line)
|
163 |
+
objects.append("\n".join(current_object))
|
164 |
+
current_object = []
|
165 |
+
else:
|
166 |
+
objects.append("\n".join(current_object))
|
167 |
+
current_object = [line]
|
168 |
+
else:
|
169 |
+
current_object.append(line)
|
170 |
+
|
171 |
+
# Add last object
|
172 |
+
if len(current_object) > 0:
|
173 |
+
objects.append("\n".join(current_object))
|
174 |
+
|
175 |
+
return objects
|
176 |
+
|
177 |
+
|
178 |
+
def extract_block(content: str, indent_level: int = 0) -> str:
|
179 |
+
"""Return the first block in `content` with the indent level `indent_level`.
|
180 |
+
|
181 |
+
The first line in `content` should be indented at `indent_level` level, otherwise an error will be thrown.
|
182 |
+
|
183 |
+
This method will immediately stop the search when a (non-empty) line with indent level less than `indent_level` is
|
184 |
+
encountered.
|
185 |
+
|
186 |
+
Args:
|
187 |
+
content (`str`): The content to parse
|
188 |
+
indent_level (`int`, *optional*, default to 0): The indent level of the blocks to search for
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
`str`: The first block in `content` with the indent level `indent_level`.
|
192 |
+
"""
|
193 |
+
current_object = []
|
194 |
+
lines = content.split("\n")
|
195 |
+
# Doc-styler takes everything between two triple quotes in docstrings, so we need a fake """ here to go with this.
|
196 |
+
end_markers = [")", "]", "}", '"""']
|
197 |
+
|
198 |
+
for idx, line in enumerate(lines):
|
199 |
+
if idx == 0 and indent_level > 0 and not is_empty_line(line) and find_indent(line) != indent_level:
|
200 |
+
raise ValueError(
|
201 |
+
f"When `indent_level > 0`, the first line in `content` should have indent level {indent_level}. Got "
|
202 |
+
f"{find_indent(line)} instead."
|
203 |
+
)
|
204 |
+
|
205 |
+
if find_indent(line) < indent_level and not is_empty_line(line):
|
206 |
+
break
|
207 |
+
|
208 |
+
# End of an object
|
209 |
+
is_valid_object = len(current_object) > 0
|
210 |
+
if (
|
211 |
+
not is_empty_line(line)
|
212 |
+
and not line.endswith(":")
|
213 |
+
and find_indent(line) == indent_level
|
214 |
+
and is_valid_object
|
215 |
+
):
|
216 |
+
# Closing parts should be included in current object
|
217 |
+
if line.lstrip() in end_markers:
|
218 |
+
current_object.append(line)
|
219 |
+
return "\n".join(current_object)
|
220 |
+
else:
|
221 |
+
current_object.append(line)
|
222 |
+
|
223 |
+
# Add last object
|
224 |
+
if len(current_object) > 0:
|
225 |
+
return "\n".join(current_object)
|
226 |
+
|
227 |
+
|
228 |
+
def add_content_to_text(
|
229 |
+
text: str,
|
230 |
+
content: str,
|
231 |
+
add_after: Optional[Union[str, Pattern]] = None,
|
232 |
+
add_before: Optional[Union[str, Pattern]] = None,
|
233 |
+
exact_match: bool = False,
|
234 |
+
) -> str:
|
235 |
+
"""
|
236 |
+
A utility to add some content inside a given text.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
text (`str`): The text in which we want to insert some content.
|
240 |
+
content (`str`): The content to add.
|
241 |
+
add_after (`str` or `Pattern`):
|
242 |
+
The pattern to test on a line of `text`, the new content is added after the first instance matching it.
|
243 |
+
add_before (`str` or `Pattern`):
|
244 |
+
The pattern to test on a line of `text`, the new content is added before the first instance matching it.
|
245 |
+
exact_match (`bool`, *optional*, defaults to `False`):
|
246 |
+
A line is considered a match with `add_after` or `add_before` if it matches exactly when `exact_match=True`,
|
247 |
+
otherwise, if `add_after`/`add_before` is present in the line.
|
248 |
+
|
249 |
+
<Tip warning={true}>
|
250 |
+
|
251 |
+
The arguments `add_after` and `add_before` are mutually exclusive, and one exactly needs to be provided.
|
252 |
+
|
253 |
+
</Tip>
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
`str`: The text with the new content added if a match was found.
|
257 |
+
"""
|
258 |
+
if add_after is None and add_before is None:
|
259 |
+
raise ValueError("You need to pass either `add_after` or `add_before`")
|
260 |
+
if add_after is not None and add_before is not None:
|
261 |
+
raise ValueError("You can't pass both `add_after` or `add_before`")
|
262 |
+
pattern = add_after if add_before is None else add_before
|
263 |
+
|
264 |
+
def this_is_the_line(line):
|
265 |
+
if isinstance(pattern, Pattern):
|
266 |
+
return pattern.search(line) is not None
|
267 |
+
elif exact_match:
|
268 |
+
return pattern == line
|
269 |
+
else:
|
270 |
+
return pattern in line
|
271 |
+
|
272 |
+
new_lines = []
|
273 |
+
for line in text.split("\n"):
|
274 |
+
if this_is_the_line(line):
|
275 |
+
if add_before is not None:
|
276 |
+
new_lines.append(content)
|
277 |
+
new_lines.append(line)
|
278 |
+
if add_after is not None:
|
279 |
+
new_lines.append(content)
|
280 |
+
else:
|
281 |
+
new_lines.append(line)
|
282 |
+
|
283 |
+
return "\n".join(new_lines)
|
284 |
+
|
285 |
+
|
286 |
+
def add_content_to_file(
|
287 |
+
file_name: Union[str, os.PathLike],
|
288 |
+
content: str,
|
289 |
+
add_after: Optional[Union[str, Pattern]] = None,
|
290 |
+
add_before: Optional[Union[str, Pattern]] = None,
|
291 |
+
exact_match: bool = False,
|
292 |
+
):
|
293 |
+
"""
|
294 |
+
A utility to add some content inside a given file.
|
295 |
+
|
296 |
+
Args:
|
297 |
+
file_name (`str` or `os.PathLike`): The name of the file in which we want to insert some content.
|
298 |
+
content (`str`): The content to add.
|
299 |
+
add_after (`str` or `Pattern`):
|
300 |
+
The pattern to test on a line of `text`, the new content is added after the first instance matching it.
|
301 |
+
add_before (`str` or `Pattern`):
|
302 |
+
The pattern to test on a line of `text`, the new content is added before the first instance matching it.
|
303 |
+
exact_match (`bool`, *optional*, defaults to `False`):
|
304 |
+
A line is considered a match with `add_after` or `add_before` if it matches exactly when `exact_match=True`,
|
305 |
+
otherwise, if `add_after`/`add_before` is present in the line.
|
306 |
+
|
307 |
+
<Tip warning={true}>
|
308 |
+
|
309 |
+
The arguments `add_after` and `add_before` are mutually exclusive, and one exactly needs to be provided.
|
310 |
+
|
311 |
+
</Tip>
|
312 |
+
"""
|
313 |
+
with open(file_name, "r", encoding="utf-8") as f:
|
314 |
+
old_content = f.read()
|
315 |
+
|
316 |
+
new_content = add_content_to_text(
|
317 |
+
old_content, content, add_after=add_after, add_before=add_before, exact_match=exact_match
|
318 |
+
)
|
319 |
+
|
320 |
+
with open(file_name, "w", encoding="utf-8") as f:
|
321 |
+
f.write(new_content)
|
322 |
+
|
323 |
+
|
324 |
+
def replace_model_patterns(
|
325 |
+
text: str, old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns
|
326 |
+
) -> Tuple[str, str]:
|
327 |
+
"""
|
328 |
+
Replace all patterns present in a given text.
|
329 |
+
|
330 |
+
Args:
|
331 |
+
text (`str`): The text to treat.
|
332 |
+
old_model_patterns (`ModelPatterns`): The patterns for the old model.
|
333 |
+
new_model_patterns (`ModelPatterns`): The patterns for the new model.
|
334 |
+
|
335 |
+
Returns:
|
336 |
+
`Tuple(str, str)`: A tuple of with the treated text and the replacement actually done in it.
|
337 |
+
"""
|
338 |
+
# The order is crucially important as we will check and replace in that order. For instance the config probably
|
339 |
+
# contains the camel-cased named, but will be treated before.
|
340 |
+
attributes_to_check = ["config_class"]
|
341 |
+
# Add relevant preprocessing classes
|
342 |
+
for attr in ["tokenizer_class", "image_processor_class", "feature_extractor_class", "processor_class"]:
|
343 |
+
if getattr(old_model_patterns, attr) is not None and getattr(new_model_patterns, attr) is not None:
|
344 |
+
attributes_to_check.append(attr)
|
345 |
+
|
346 |
+
# Special cases for checkpoint and model_type
|
347 |
+
if old_model_patterns.checkpoint not in [old_model_patterns.model_type, old_model_patterns.model_lower_cased]:
|
348 |
+
attributes_to_check.append("checkpoint")
|
349 |
+
if old_model_patterns.model_type != old_model_patterns.model_lower_cased:
|
350 |
+
attributes_to_check.append("model_type")
|
351 |
+
else:
|
352 |
+
text = re.sub(
|
353 |
+
rf'(\s*)model_type = "{old_model_patterns.model_type}"',
|
354 |
+
r'\1model_type = "[MODEL_TYPE]"',
|
355 |
+
text,
|
356 |
+
)
|
357 |
+
|
358 |
+
# Special case when the model camel cased and upper cased names are the same for the old model (like for GPT2) but
|
359 |
+
# not the new one. We can't just do a replace in all the text and will need a special regex
|
360 |
+
if old_model_patterns.model_upper_cased == old_model_patterns.model_camel_cased:
|
361 |
+
old_model_value = old_model_patterns.model_upper_cased
|
362 |
+
if re.search(rf"{old_model_value}_[A-Z_]*[^A-Z_]", text) is not None:
|
363 |
+
text = re.sub(rf"{old_model_value}([A-Z_]*)([^a-zA-Z_])", r"[MODEL_UPPER_CASED]\1\2", text)
|
364 |
+
else:
|
365 |
+
attributes_to_check.append("model_upper_cased")
|
366 |
+
|
367 |
+
attributes_to_check.extend(["model_camel_cased", "model_lower_cased", "model_name"])
|
368 |
+
|
369 |
+
# Now let's replace every other attribute by their placeholder
|
370 |
+
for attr in attributes_to_check:
|
371 |
+
text = text.replace(getattr(old_model_patterns, attr), ATTRIBUTE_TO_PLACEHOLDER[attr])
|
372 |
+
|
373 |
+
# Finally we can replace the placeholder byt the new values.
|
374 |
+
replacements = []
|
375 |
+
for attr, placeholder in ATTRIBUTE_TO_PLACEHOLDER.items():
|
376 |
+
if placeholder in text:
|
377 |
+
replacements.append((getattr(old_model_patterns, attr), getattr(new_model_patterns, attr)))
|
378 |
+
text = text.replace(placeholder, getattr(new_model_patterns, attr))
|
379 |
+
|
380 |
+
# If we have two inconsistent replacements, we don't return anything (ex: GPT2->GPT_NEW and GPT2->GPTNew)
|
381 |
+
old_replacement_values = [old for old, new in replacements]
|
382 |
+
if len(set(old_replacement_values)) != len(old_replacement_values):
|
383 |
+
return text, ""
|
384 |
+
|
385 |
+
replacements = simplify_replacements(replacements)
|
386 |
+
replacements = [f"{old}->{new}" for old, new in replacements]
|
387 |
+
return text, ",".join(replacements)
|
388 |
+
|
389 |
+
|
390 |
+
def simplify_replacements(replacements):
|
391 |
+
"""
|
392 |
+
Simplify a list of replacement patterns to make sure there are no needless ones.
|
393 |
+
|
394 |
+
For instance in the sequence "Bert->BertNew, BertConfig->BertNewConfig, bert->bert_new", the replacement
|
395 |
+
"BertConfig->BertNewConfig" is implied by "Bert->BertNew" so not needed.
|
396 |
+
|
397 |
+
Args:
|
398 |
+
replacements (`List[Tuple[str, str]]`): List of patterns (old, new)
|
399 |
+
|
400 |
+
Returns:
|
401 |
+
`List[Tuple[str, str]]`: The list of patterns simplified.
|
402 |
+
"""
|
403 |
+
if len(replacements) <= 1:
|
404 |
+
# Nothing to simplify
|
405 |
+
return replacements
|
406 |
+
|
407 |
+
# Next let's sort replacements by length as a replacement can only "imply" another replacement if it's shorter.
|
408 |
+
replacements.sort(key=lambda x: len(x[0]))
|
409 |
+
|
410 |
+
idx = 0
|
411 |
+
while idx < len(replacements):
|
412 |
+
old, new = replacements[idx]
|
413 |
+
# Loop through all replacements after
|
414 |
+
j = idx + 1
|
415 |
+
while j < len(replacements):
|
416 |
+
old_2, new_2 = replacements[j]
|
417 |
+
# If the replacement is implied by the current one, we can drop it.
|
418 |
+
if old_2.replace(old, new) == new_2:
|
419 |
+
replacements.pop(j)
|
420 |
+
else:
|
421 |
+
j += 1
|
422 |
+
idx += 1
|
423 |
+
|
424 |
+
return replacements
|
425 |
+
|
426 |
+
|
427 |
+
def get_module_from_file(module_file: Union[str, os.PathLike]) -> str:
|
428 |
+
"""
|
429 |
+
Returns the module name corresponding to a module file.
|
430 |
+
"""
|
431 |
+
full_module_path = Path(module_file).absolute()
|
432 |
+
module_parts = full_module_path.with_suffix("").parts
|
433 |
+
|
434 |
+
# Find the first part named transformers, starting from the end.
|
435 |
+
idx = len(module_parts) - 1
|
436 |
+
while idx >= 0 and module_parts[idx] != "transformers":
|
437 |
+
idx -= 1
|
438 |
+
if idx < 0:
|
439 |
+
raise ValueError(f"{module_file} is not a transformers module.")
|
440 |
+
|
441 |
+
return ".".join(module_parts[idx:])
|
442 |
+
|
443 |
+
|
444 |
+
SPECIAL_PATTERNS = {
|
445 |
+
"_CHECKPOINT_FOR_DOC =": "checkpoint",
|
446 |
+
"_CONFIG_FOR_DOC =": "config_class",
|
447 |
+
"_TOKENIZER_FOR_DOC =": "tokenizer_class",
|
448 |
+
"_IMAGE_PROCESSOR_FOR_DOC =": "image_processor_class",
|
449 |
+
"_FEAT_EXTRACTOR_FOR_DOC =": "feature_extractor_class",
|
450 |
+
"_PROCESSOR_FOR_DOC =": "processor_class",
|
451 |
+
}
|
452 |
+
|
453 |
+
|
454 |
+
_re_class_func = re.compile(r"^(?:class|def)\s+([^\s:\(]+)\s*(?:\(|\:)", flags=re.MULTILINE)
|
455 |
+
|
456 |
+
|
457 |
+
def remove_attributes(obj, target_attr):
|
458 |
+
"""Remove `target_attr` in `obj`."""
|
459 |
+
lines = obj.split(os.linesep)
|
460 |
+
|
461 |
+
target_idx = None
|
462 |
+
for idx, line in enumerate(lines):
|
463 |
+
# search for assignment
|
464 |
+
if line.lstrip().startswith(f"{target_attr} = "):
|
465 |
+
target_idx = idx
|
466 |
+
break
|
467 |
+
# search for function/method definition
|
468 |
+
elif line.lstrip().startswith(f"def {target_attr}("):
|
469 |
+
target_idx = idx
|
470 |
+
break
|
471 |
+
|
472 |
+
# target not found
|
473 |
+
if target_idx is None:
|
474 |
+
return obj
|
475 |
+
|
476 |
+
line = lines[target_idx]
|
477 |
+
indent_level = find_indent(line)
|
478 |
+
# forward pass to find the ending of the block (including empty lines)
|
479 |
+
parsed = extract_block("\n".join(lines[target_idx:]), indent_level)
|
480 |
+
num_lines = len(parsed.split("\n"))
|
481 |
+
for idx in range(num_lines):
|
482 |
+
lines[target_idx + idx] = None
|
483 |
+
|
484 |
+
# backward pass to find comments or decorator
|
485 |
+
for idx in range(target_idx - 1, -1, -1):
|
486 |
+
line = lines[idx]
|
487 |
+
if (line.lstrip().startswith("#") or line.lstrip().startswith("@")) and find_indent(line) == indent_level:
|
488 |
+
lines[idx] = None
|
489 |
+
else:
|
490 |
+
break
|
491 |
+
|
492 |
+
new_obj = os.linesep.join([x for x in lines if x is not None])
|
493 |
+
|
494 |
+
return new_obj
|
495 |
+
|
496 |
+
|
497 |
+
def duplicate_module(
|
498 |
+
module_file: Union[str, os.PathLike],
|
499 |
+
old_model_patterns: ModelPatterns,
|
500 |
+
new_model_patterns: ModelPatterns,
|
501 |
+
dest_file: Optional[str] = None,
|
502 |
+
add_copied_from: bool = True,
|
503 |
+
attrs_to_remove: List[str] = None,
|
504 |
+
):
|
505 |
+
"""
|
506 |
+
Create a new module from an existing one and adapting all function and classes names from old patterns to new ones.
|
507 |
+
|
508 |
+
Args:
|
509 |
+
module_file (`str` or `os.PathLike`): Path to the module to duplicate.
|
510 |
+
old_model_patterns (`ModelPatterns`): The patterns for the old model.
|
511 |
+
new_model_patterns (`ModelPatterns`): The patterns for the new model.
|
512 |
+
dest_file (`str` or `os.PathLike`, *optional*): Path to the new module.
|
513 |
+
add_copied_from (`bool`, *optional*, defaults to `True`):
|
514 |
+
Whether or not to add `# Copied from` statements in the duplicated module.
|
515 |
+
"""
|
516 |
+
if dest_file is None:
|
517 |
+
dest_file = str(module_file).replace(
|
518 |
+
old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
|
519 |
+
)
|
520 |
+
|
521 |
+
with open(module_file, "r", encoding="utf-8") as f:
|
522 |
+
content = f.read()
|
523 |
+
|
524 |
+
content = re.sub(r"# Copyright (\d+)\s", f"# Copyright {CURRENT_YEAR} ", content)
|
525 |
+
objects = parse_module_content(content)
|
526 |
+
|
527 |
+
# Loop and treat all objects
|
528 |
+
new_objects = []
|
529 |
+
for obj in objects:
|
530 |
+
# Special cases
|
531 |
+
if "PRETRAINED_CONFIG_ARCHIVE_MAP = {" in obj:
|
532 |
+
# docstyle-ignore
|
533 |
+
obj = (
|
534 |
+
f"{new_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP = "
|
535 |
+
+ "{"
|
536 |
+
+ f"""
|
537 |
+
"{new_model_patterns.checkpoint}": "https://huggingface.co/{new_model_patterns.checkpoint}/resolve/main/config.json",
|
538 |
+
"""
|
539 |
+
+ "}\n"
|
540 |
+
)
|
541 |
+
new_objects.append(obj)
|
542 |
+
continue
|
543 |
+
elif "PRETRAINED_MODEL_ARCHIVE_LIST = [" in obj:
|
544 |
+
if obj.startswith("TF_"):
|
545 |
+
prefix = "TF_"
|
546 |
+
elif obj.startswith("FLAX_"):
|
547 |
+
prefix = "FLAX_"
|
548 |
+
else:
|
549 |
+
prefix = ""
|
550 |
+
# docstyle-ignore
|
551 |
+
obj = f"""{prefix}{new_model_patterns.model_upper_cased}_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
552 |
+
"{new_model_patterns.checkpoint}",
|
553 |
+
# See all {new_model_patterns.model_name} models at https://huggingface.co/models?filter={new_model_patterns.model_type}
|
554 |
+
]
|
555 |
+
"""
|
556 |
+
new_objects.append(obj)
|
557 |
+
continue
|
558 |
+
|
559 |
+
special_pattern = False
|
560 |
+
for pattern, attr in SPECIAL_PATTERNS.items():
|
561 |
+
if pattern in obj:
|
562 |
+
obj = obj.replace(getattr(old_model_patterns, attr), getattr(new_model_patterns, attr))
|
563 |
+
new_objects.append(obj)
|
564 |
+
special_pattern = True
|
565 |
+
break
|
566 |
+
|
567 |
+
if special_pattern:
|
568 |
+
continue
|
569 |
+
|
570 |
+
# Regular classes functions
|
571 |
+
old_obj = obj
|
572 |
+
obj, replacement = replace_model_patterns(obj, old_model_patterns, new_model_patterns)
|
573 |
+
has_copied_from = re.search(r"^#\s+Copied from", obj, flags=re.MULTILINE) is not None
|
574 |
+
if add_copied_from and not has_copied_from and _re_class_func.search(obj) is not None and len(replacement) > 0:
|
575 |
+
# Copied from statement must be added just before the class/function definition, which may not be the
|
576 |
+
# first line because of decorators.
|
577 |
+
module_name = get_module_from_file(module_file)
|
578 |
+
old_object_name = _re_class_func.search(old_obj).groups()[0]
|
579 |
+
obj = add_content_to_text(
|
580 |
+
obj, f"# Copied from {module_name}.{old_object_name} with {replacement}", add_before=_re_class_func
|
581 |
+
)
|
582 |
+
# In all cases, we remove Copied from statement with indent on methods.
|
583 |
+
obj = re.sub("\n[ ]+# Copied from [^\n]*\n", "\n", obj)
|
584 |
+
|
585 |
+
new_objects.append(obj)
|
586 |
+
|
587 |
+
content = "\n".join(new_objects)
|
588 |
+
# Remove some attributes that we don't want to copy to the new file(s)
|
589 |
+
if attrs_to_remove is not None:
|
590 |
+
for attr in attrs_to_remove:
|
591 |
+
content = remove_attributes(content, target_attr=attr)
|
592 |
+
|
593 |
+
with open(dest_file, "w", encoding="utf-8") as f:
|
594 |
+
f.write(content)
|
595 |
+
|
596 |
+
|
597 |
+
def filter_framework_files(
|
598 |
+
files: List[Union[str, os.PathLike]], frameworks: Optional[List[str]] = None
|
599 |
+
) -> List[Union[str, os.PathLike]]:
|
600 |
+
"""
|
601 |
+
Filter a list of files to only keep the ones corresponding to a list of frameworks.
|
602 |
+
|
603 |
+
Args:
|
604 |
+
files (`List[Union[str, os.PathLike]]`): The list of files to filter.
|
605 |
+
frameworks (`List[str]`, *optional*): The list of allowed frameworks.
|
606 |
+
|
607 |
+
Returns:
|
608 |
+
`List[Union[str, os.PathLike]]`: The list of filtered files.
|
609 |
+
"""
|
610 |
+
if frameworks is None:
|
611 |
+
frameworks = get_default_frameworks()
|
612 |
+
|
613 |
+
framework_to_file = {}
|
614 |
+
others = []
|
615 |
+
for f in files:
|
616 |
+
parts = Path(f).name.split("_")
|
617 |
+
if "modeling" not in parts:
|
618 |
+
others.append(f)
|
619 |
+
continue
|
620 |
+
if "tf" in parts:
|
621 |
+
framework_to_file["tf"] = f
|
622 |
+
elif "flax" in parts:
|
623 |
+
framework_to_file["flax"] = f
|
624 |
+
else:
|
625 |
+
framework_to_file["pt"] = f
|
626 |
+
|
627 |
+
return [framework_to_file[f] for f in frameworks if f in framework_to_file] + others
|
628 |
+
|
629 |
+
|
630 |
+
def get_model_files(model_type: str, frameworks: Optional[List[str]] = None) -> Dict[str, Union[Path, List[Path]]]:
|
631 |
+
"""
|
632 |
+
Retrieves all the files associated to a model.
|
633 |
+
|
634 |
+
Args:
|
635 |
+
model_type (`str`): A valid model type (like "bert" or "gpt2")
|
636 |
+
frameworks (`List[str]`, *optional*):
|
637 |
+
If passed, will only keep the model files corresponding to the passed frameworks.
|
638 |
+
|
639 |
+
Returns:
|
640 |
+
`Dict[str, Union[Path, List[Path]]]`: A dictionary with the following keys:
|
641 |
+
- **doc_file** -- The documentation file for the model.
|
642 |
+
- **model_files** -- All the files in the model module.
|
643 |
+
- **test_files** -- The test files for the model.
|
644 |
+
"""
|
645 |
+
module_name = model_type_to_module_name(model_type)
|
646 |
+
|
647 |
+
model_module = TRANSFORMERS_PATH / "models" / module_name
|
648 |
+
model_files = list(model_module.glob("*.py"))
|
649 |
+
model_files = filter_framework_files(model_files, frameworks=frameworks)
|
650 |
+
|
651 |
+
doc_file = REPO_PATH / "docs" / "source" / "en" / "model_doc" / f"{model_type}.md"
|
652 |
+
|
653 |
+
# Basic pattern for test files
|
654 |
+
test_files = [
|
655 |
+
f"test_modeling_{module_name}.py",
|
656 |
+
f"test_modeling_tf_{module_name}.py",
|
657 |
+
f"test_modeling_flax_{module_name}.py",
|
658 |
+
f"test_tokenization_{module_name}.py",
|
659 |
+
f"test_image_processing_{module_name}.py",
|
660 |
+
f"test_feature_extraction_{module_name}.py",
|
661 |
+
f"test_processor_{module_name}.py",
|
662 |
+
]
|
663 |
+
test_files = filter_framework_files(test_files, frameworks=frameworks)
|
664 |
+
# Add the test directory
|
665 |
+
test_files = [REPO_PATH / "tests" / "models" / module_name / f for f in test_files]
|
666 |
+
# Filter by existing files
|
667 |
+
test_files = [f for f in test_files if f.exists()]
|
668 |
+
|
669 |
+
return {"doc_file": doc_file, "model_files": model_files, "module_name": module_name, "test_files": test_files}
|
670 |
+
|
671 |
+
|
672 |
+
_re_checkpoint_for_doc = re.compile(r"^_CHECKPOINT_FOR_DOC\s+=\s+(\S*)\s*$", flags=re.MULTILINE)
|
673 |
+
|
674 |
+
|
675 |
+
def find_base_model_checkpoint(
|
676 |
+
model_type: str, model_files: Optional[Dict[str, Union[Path, List[Path]]]] = None
|
677 |
+
) -> str:
|
678 |
+
"""
|
679 |
+
Finds the model checkpoint used in the docstrings for a given model.
|
680 |
+
|
681 |
+
Args:
|
682 |
+
model_type (`str`): A valid model type (like "bert" or "gpt2")
|
683 |
+
model_files (`Dict[str, Union[Path, List[Path]]`, *optional*):
|
684 |
+
The files associated to `model_type`. Can be passed to speed up the function, otherwise will be computed.
|
685 |
+
|
686 |
+
Returns:
|
687 |
+
`str`: The checkpoint used.
|
688 |
+
"""
|
689 |
+
if model_files is None:
|
690 |
+
model_files = get_model_files(model_type)
|
691 |
+
module_files = model_files["model_files"]
|
692 |
+
for fname in module_files:
|
693 |
+
if "modeling" not in str(fname):
|
694 |
+
continue
|
695 |
+
|
696 |
+
with open(fname, "r", encoding="utf-8") as f:
|
697 |
+
content = f.read()
|
698 |
+
if _re_checkpoint_for_doc.search(content) is not None:
|
699 |
+
checkpoint = _re_checkpoint_for_doc.search(content).groups()[0]
|
700 |
+
# Remove quotes
|
701 |
+
checkpoint = checkpoint.replace('"', "")
|
702 |
+
checkpoint = checkpoint.replace("'", "")
|
703 |
+
return checkpoint
|
704 |
+
|
705 |
+
# TODO: Find some kind of fallback if there is no _CHECKPOINT_FOR_DOC in any of the modeling file.
|
706 |
+
return ""
|
707 |
+
|
708 |
+
|
709 |
+
def get_default_frameworks():
|
710 |
+
"""
|
711 |
+
Returns the list of frameworks (PyTorch, TensorFlow, Flax) that are installed in the environment.
|
712 |
+
"""
|
713 |
+
frameworks = []
|
714 |
+
if is_torch_available():
|
715 |
+
frameworks.append("pt")
|
716 |
+
if is_tf_available():
|
717 |
+
frameworks.append("tf")
|
718 |
+
if is_flax_available():
|
719 |
+
frameworks.append("flax")
|
720 |
+
return frameworks
|
721 |
+
|
722 |
+
|
723 |
+
_re_model_mapping = re.compile("MODEL_([A-Z_]*)MAPPING_NAMES")
|
724 |
+
|
725 |
+
|
726 |
+
def retrieve_model_classes(model_type: str, frameworks: Optional[List[str]] = None) -> Dict[str, List[str]]:
|
727 |
+
"""
|
728 |
+
Retrieve the model classes associated to a given model.
|
729 |
+
|
730 |
+
Args:
|
731 |
+
model_type (`str`): A valid model type (like "bert" or "gpt2")
|
732 |
+
frameworks (`List[str]`, *optional*):
|
733 |
+
The frameworks to look for. Will default to `["pt", "tf", "flax"]`, passing a smaller list will restrict
|
734 |
+
the classes returned.
|
735 |
+
|
736 |
+
Returns:
|
737 |
+
`Dict[str, List[str]]`: A dictionary with one key per framework and the list of model classes associated to
|
738 |
+
that framework as values.
|
739 |
+
"""
|
740 |
+
if frameworks is None:
|
741 |
+
frameworks = get_default_frameworks()
|
742 |
+
|
743 |
+
modules = {
|
744 |
+
"pt": auto_module.modeling_auto if is_torch_available() else None,
|
745 |
+
"tf": auto_module.modeling_tf_auto if is_tf_available() else None,
|
746 |
+
"flax": auto_module.modeling_flax_auto if is_flax_available() else None,
|
747 |
+
}
|
748 |
+
|
749 |
+
model_classes = {}
|
750 |
+
for framework in frameworks:
|
751 |
+
new_model_classes = []
|
752 |
+
if modules[framework] is None:
|
753 |
+
raise ValueError(f"You selected {framework} in the frameworks, but it is not installed.")
|
754 |
+
model_mappings = [attr for attr in dir(modules[framework]) if _re_model_mapping.search(attr) is not None]
|
755 |
+
for model_mapping_name in model_mappings:
|
756 |
+
model_mapping = getattr(modules[framework], model_mapping_name)
|
757 |
+
if model_type in model_mapping:
|
758 |
+
new_model_classes.append(model_mapping[model_type])
|
759 |
+
|
760 |
+
if len(new_model_classes) > 0:
|
761 |
+
# Remove duplicates
|
762 |
+
model_classes[framework] = list(set(new_model_classes))
|
763 |
+
|
764 |
+
return model_classes
|
765 |
+
|
766 |
+
|
767 |
+
def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None):
|
768 |
+
"""
|
769 |
+
Retrieves all the information from a given model_type.
|
770 |
+
|
771 |
+
Args:
|
772 |
+
model_type (`str`): A valid model type (like "bert" or "gpt2")
|
773 |
+
frameworks (`List[str]`, *optional*):
|
774 |
+
If passed, will only keep the info corresponding to the passed frameworks.
|
775 |
+
|
776 |
+
Returns:
|
777 |
+
`Dict`: A dictionary with the following keys:
|
778 |
+
- **frameworks** (`List[str]`): The list of frameworks that back this model type.
|
779 |
+
- **model_classes** (`Dict[str, List[str]]`): The model classes implemented for that model type.
|
780 |
+
- **model_files** (`Dict[str, Union[Path, List[Path]]]`): The files associated with that model type.
|
781 |
+
- **model_patterns** (`ModelPatterns`): The various patterns for the model.
|
782 |
+
"""
|
783 |
+
if model_type not in auto_module.MODEL_NAMES_MAPPING:
|
784 |
+
raise ValueError(f"{model_type} is not a valid model type.")
|
785 |
+
|
786 |
+
model_name = auto_module.MODEL_NAMES_MAPPING[model_type]
|
787 |
+
config_class = auto_module.configuration_auto.CONFIG_MAPPING_NAMES[model_type]
|
788 |
+
archive_map = auto_module.configuration_auto.CONFIG_ARCHIVE_MAP_MAPPING_NAMES.get(model_type, None)
|
789 |
+
if model_type in auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES:
|
790 |
+
tokenizer_classes = auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES[model_type]
|
791 |
+
tokenizer_class = tokenizer_classes[0] if tokenizer_classes[0] is not None else tokenizer_classes[1]
|
792 |
+
else:
|
793 |
+
tokenizer_class = None
|
794 |
+
image_processor_class = auto_module.image_processing_auto.IMAGE_PROCESSOR_MAPPING_NAMES.get(model_type, None)
|
795 |
+
feature_extractor_class = auto_module.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES.get(model_type, None)
|
796 |
+
processor_class = auto_module.processing_auto.PROCESSOR_MAPPING_NAMES.get(model_type, None)
|
797 |
+
|
798 |
+
model_files = get_model_files(model_type, frameworks=frameworks)
|
799 |
+
model_camel_cased = config_class.replace("Config", "")
|
800 |
+
|
801 |
+
available_frameworks = []
|
802 |
+
for fname in model_files["model_files"]:
|
803 |
+
if "modeling_tf" in str(fname):
|
804 |
+
available_frameworks.append("tf")
|
805 |
+
elif "modeling_flax" in str(fname):
|
806 |
+
available_frameworks.append("flax")
|
807 |
+
elif "modeling" in str(fname):
|
808 |
+
available_frameworks.append("pt")
|
809 |
+
|
810 |
+
if frameworks is None:
|
811 |
+
frameworks = get_default_frameworks()
|
812 |
+
|
813 |
+
frameworks = [f for f in frameworks if f in available_frameworks]
|
814 |
+
|
815 |
+
model_classes = retrieve_model_classes(model_type, frameworks=frameworks)
|
816 |
+
|
817 |
+
# Retrieve model upper-cased name from the constant name of the pretrained archive map.
|
818 |
+
if archive_map is None:
|
819 |
+
model_upper_cased = model_camel_cased.upper()
|
820 |
+
else:
|
821 |
+
parts = archive_map.split("_")
|
822 |
+
idx = 0
|
823 |
+
while idx < len(parts) and parts[idx] != "PRETRAINED":
|
824 |
+
idx += 1
|
825 |
+
if idx < len(parts):
|
826 |
+
model_upper_cased = "_".join(parts[:idx])
|
827 |
+
else:
|
828 |
+
model_upper_cased = model_camel_cased.upper()
|
829 |
+
|
830 |
+
model_patterns = ModelPatterns(
|
831 |
+
model_name,
|
832 |
+
checkpoint=find_base_model_checkpoint(model_type, model_files=model_files),
|
833 |
+
model_type=model_type,
|
834 |
+
model_camel_cased=model_camel_cased,
|
835 |
+
model_lower_cased=model_files["module_name"],
|
836 |
+
model_upper_cased=model_upper_cased,
|
837 |
+
config_class=config_class,
|
838 |
+
tokenizer_class=tokenizer_class,
|
839 |
+
image_processor_class=image_processor_class,
|
840 |
+
feature_extractor_class=feature_extractor_class,
|
841 |
+
processor_class=processor_class,
|
842 |
+
)
|
843 |
+
|
844 |
+
return {
|
845 |
+
"frameworks": frameworks,
|
846 |
+
"model_classes": model_classes,
|
847 |
+
"model_files": model_files,
|
848 |
+
"model_patterns": model_patterns,
|
849 |
+
}
|
850 |
+
|
851 |
+
|
852 |
+
def clean_frameworks_in_init(
|
853 |
+
init_file: Union[str, os.PathLike], frameworks: Optional[List[str]] = None, keep_processing: bool = True
|
854 |
+
):
|
855 |
+
"""
|
856 |
+
Removes all the import lines that don't belong to a given list of frameworks or concern tokenizers/feature
|
857 |
+
extractors/image processors/processors in an init.
|
858 |
+
|
859 |
+
Args:
|
860 |
+
init_file (`str` or `os.PathLike`): The path to the init to treat.
|
861 |
+
frameworks (`List[str]`, *optional*):
|
862 |
+
If passed, this will remove all imports that are subject to a framework not in frameworks
|
863 |
+
keep_processing (`bool`, *optional*, defaults to `True`):
|
864 |
+
Whether or not to keep the preprocessing (tokenizer, feature extractor, image processor, processor) imports
|
865 |
+
in the init.
|
866 |
+
"""
|
867 |
+
if frameworks is None:
|
868 |
+
frameworks = get_default_frameworks()
|
869 |
+
|
870 |
+
names = {"pt": "torch"}
|
871 |
+
to_remove = [names.get(f, f) for f in ["pt", "tf", "flax"] if f not in frameworks]
|
872 |
+
if not keep_processing:
|
873 |
+
to_remove.extend(["sentencepiece", "tokenizers", "vision"])
|
874 |
+
|
875 |
+
if len(to_remove) == 0:
|
876 |
+
# Nothing to do
|
877 |
+
return
|
878 |
+
|
879 |
+
remove_pattern = "|".join(to_remove)
|
880 |
+
re_conditional_imports = re.compile(rf"^\s*if not is_({remove_pattern})_available\(\):\s*$")
|
881 |
+
re_try = re.compile(r"\s*try:")
|
882 |
+
re_else = re.compile(r"\s*else:")
|
883 |
+
re_is_xxx_available = re.compile(rf"is_({remove_pattern})_available")
|
884 |
+
|
885 |
+
with open(init_file, "r", encoding="utf-8") as f:
|
886 |
+
content = f.read()
|
887 |
+
|
888 |
+
lines = content.split("\n")
|
889 |
+
new_lines = []
|
890 |
+
idx = 0
|
891 |
+
while idx < len(lines):
|
892 |
+
# Conditional imports in try-except-else blocks
|
893 |
+
if (re_conditional_imports.search(lines[idx]) is not None) and (re_try.search(lines[idx - 1]) is not None):
|
894 |
+
# Remove the preceding `try:`
|
895 |
+
new_lines.pop()
|
896 |
+
idx += 1
|
897 |
+
# Iterate until `else:`
|
898 |
+
while is_empty_line(lines[idx]) or re_else.search(lines[idx]) is None:
|
899 |
+
idx += 1
|
900 |
+
idx += 1
|
901 |
+
indent = find_indent(lines[idx])
|
902 |
+
while find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]):
|
903 |
+
idx += 1
|
904 |
+
# Remove the import from utils
|
905 |
+
elif re_is_xxx_available.search(lines[idx]) is not None:
|
906 |
+
line = lines[idx]
|
907 |
+
for framework in to_remove:
|
908 |
+
line = line.replace(f", is_{framework}_available", "")
|
909 |
+
line = line.replace(f"is_{framework}_available, ", "")
|
910 |
+
line = line.replace(f"is_{framework}_available,", "")
|
911 |
+
line = line.replace(f"is_{framework}_available", "")
|
912 |
+
|
913 |
+
if len(line.strip()) > 0:
|
914 |
+
new_lines.append(line)
|
915 |
+
idx += 1
|
916 |
+
# Otherwise we keep the line, except if it's a tokenizer import and we don't want to keep it.
|
917 |
+
elif keep_processing or (
|
918 |
+
re.search(r'^\s*"(tokenization|processing|feature_extraction|image_processing)', lines[idx]) is None
|
919 |
+
and re.search(r"^\s*from .(tokenization|processing|feature_extraction|image_processing)", lines[idx])
|
920 |
+
is None
|
921 |
+
):
|
922 |
+
new_lines.append(lines[idx])
|
923 |
+
idx += 1
|
924 |
+
else:
|
925 |
+
idx += 1
|
926 |
+
|
927 |
+
with open(init_file, "w", encoding="utf-8") as f:
|
928 |
+
f.write("\n".join(new_lines))
|
929 |
+
|
930 |
+
|
931 |
+
def add_model_to_main_init(
|
932 |
+
old_model_patterns: ModelPatterns,
|
933 |
+
new_model_patterns: ModelPatterns,
|
934 |
+
frameworks: Optional[List[str]] = None,
|
935 |
+
with_processing: bool = True,
|
936 |
+
):
|
937 |
+
"""
|
938 |
+
Add a model to the main init of Transformers.
|
939 |
+
|
940 |
+
Args:
|
941 |
+
old_model_patterns (`ModelPatterns`): The patterns for the old model.
|
942 |
+
new_model_patterns (`ModelPatterns`): The patterns for the new model.
|
943 |
+
frameworks (`List[str]`, *optional*):
|
944 |
+
If specified, only the models implemented in those frameworks will be added.
|
945 |
+
with_processsing (`bool`, *optional*, defaults to `True`):
|
946 |
+
Whether the tokenizer/feature extractor/processor of the model should also be added to the init or not.
|
947 |
+
"""
|
948 |
+
with open(TRANSFORMERS_PATH / "__init__.py", "r", encoding="utf-8") as f:
|
949 |
+
content = f.read()
|
950 |
+
|
951 |
+
lines = content.split("\n")
|
952 |
+
idx = 0
|
953 |
+
new_lines = []
|
954 |
+
framework = None
|
955 |
+
while idx < len(lines):
|
956 |
+
new_framework = False
|
957 |
+
if not is_empty_line(lines[idx]) and find_indent(lines[idx]) == 0:
|
958 |
+
framework = None
|
959 |
+
elif lines[idx].lstrip().startswith("if not is_torch_available"):
|
960 |
+
framework = "pt"
|
961 |
+
new_framework = True
|
962 |
+
elif lines[idx].lstrip().startswith("if not is_tf_available"):
|
963 |
+
framework = "tf"
|
964 |
+
new_framework = True
|
965 |
+
elif lines[idx].lstrip().startswith("if not is_flax_available"):
|
966 |
+
framework = "flax"
|
967 |
+
new_framework = True
|
968 |
+
|
969 |
+
if new_framework:
|
970 |
+
# For a new framework, we need to skip until the else: block to get where the imports are.
|
971 |
+
while lines[idx].strip() != "else:":
|
972 |
+
new_lines.append(lines[idx])
|
973 |
+
idx += 1
|
974 |
+
|
975 |
+
# Skip if we are in a framework not wanted.
|
976 |
+
if framework is not None and frameworks is not None and framework not in frameworks:
|
977 |
+
new_lines.append(lines[idx])
|
978 |
+
idx += 1
|
979 |
+
elif re.search(rf'models.{old_model_patterns.model_lower_cased}( |")', lines[idx]) is not None:
|
980 |
+
block = [lines[idx]]
|
981 |
+
indent = find_indent(lines[idx])
|
982 |
+
idx += 1
|
983 |
+
while find_indent(lines[idx]) > indent:
|
984 |
+
block.append(lines[idx])
|
985 |
+
idx += 1
|
986 |
+
if lines[idx].strip() in [")", "]", "],"]:
|
987 |
+
block.append(lines[idx])
|
988 |
+
idx += 1
|
989 |
+
block = "\n".join(block)
|
990 |
+
new_lines.append(block)
|
991 |
+
|
992 |
+
add_block = True
|
993 |
+
if not with_processing:
|
994 |
+
processing_classes = [
|
995 |
+
old_model_patterns.tokenizer_class,
|
996 |
+
old_model_patterns.image_processor_class,
|
997 |
+
old_model_patterns.feature_extractor_class,
|
998 |
+
old_model_patterns.processor_class,
|
999 |
+
]
|
1000 |
+
# Only keep the ones that are not None
|
1001 |
+
processing_classes = [c for c in processing_classes if c is not None]
|
1002 |
+
for processing_class in processing_classes:
|
1003 |
+
block = block.replace(f' "{processing_class}",', "")
|
1004 |
+
block = block.replace(f', "{processing_class}"', "")
|
1005 |
+
block = block.replace(f" {processing_class},", "")
|
1006 |
+
block = block.replace(f", {processing_class}", "")
|
1007 |
+
|
1008 |
+
if processing_class in block:
|
1009 |
+
add_block = False
|
1010 |
+
if add_block:
|
1011 |
+
new_lines.append(replace_model_patterns(block, old_model_patterns, new_model_patterns)[0])
|
1012 |
+
else:
|
1013 |
+
new_lines.append(lines[idx])
|
1014 |
+
idx += 1
|
1015 |
+
|
1016 |
+
with open(TRANSFORMERS_PATH / "__init__.py", "w", encoding="utf-8") as f:
|
1017 |
+
f.write("\n".join(new_lines))
|
1018 |
+
|
1019 |
+
|
1020 |
+
def insert_tokenizer_in_auto_module(old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns):
|
1021 |
+
"""
|
1022 |
+
Add a tokenizer to the relevant mappings in the auto module.
|
1023 |
+
|
1024 |
+
Args:
|
1025 |
+
old_model_patterns (`ModelPatterns`): The patterns for the old model.
|
1026 |
+
new_model_patterns (`ModelPatterns`): The patterns for the new model.
|
1027 |
+
"""
|
1028 |
+
if old_model_patterns.tokenizer_class is None or new_model_patterns.tokenizer_class is None:
|
1029 |
+
return
|
1030 |
+
|
1031 |
+
with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "r", encoding="utf-8") as f:
|
1032 |
+
content = f.read()
|
1033 |
+
|
1034 |
+
lines = content.split("\n")
|
1035 |
+
idx = 0
|
1036 |
+
# First we get to the TOKENIZER_MAPPING_NAMES block.
|
1037 |
+
while not lines[idx].startswith(" TOKENIZER_MAPPING_NAMES = OrderedDict("):
|
1038 |
+
idx += 1
|
1039 |
+
idx += 1
|
1040 |
+
|
1041 |
+
# That block will end at this prompt:
|
1042 |
+
while not lines[idx].startswith("TOKENIZER_MAPPING = _LazyAutoMapping"):
|
1043 |
+
# Either all the tokenizer block is defined on one line, in which case, it ends with "),"
|
1044 |
+
if lines[idx].endswith(","):
|
1045 |
+
block = lines[idx]
|
1046 |
+
# Otherwise it takes several lines until we get to a "),"
|
1047 |
+
else:
|
1048 |
+
block = []
|
1049 |
+
while not lines[idx].startswith(" ),"):
|
1050 |
+
block.append(lines[idx])
|
1051 |
+
idx += 1
|
1052 |
+
block = "\n".join(block)
|
1053 |
+
idx += 1
|
1054 |
+
|
1055 |
+
# If we find the model type and tokenizer class in that block, we have the old model tokenizer block
|
1056 |
+
if f'"{old_model_patterns.model_type}"' in block and old_model_patterns.tokenizer_class in block:
|
1057 |
+
break
|
1058 |
+
|
1059 |
+
new_block = block.replace(old_model_patterns.model_type, new_model_patterns.model_type)
|
1060 |
+
new_block = new_block.replace(old_model_patterns.tokenizer_class, new_model_patterns.tokenizer_class)
|
1061 |
+
|
1062 |
+
new_lines = lines[:idx] + [new_block] + lines[idx:]
|
1063 |
+
with open(TRANSFORMERS_PATH / "models" / "auto" / "tokenization_auto.py", "w", encoding="utf-8") as f:
|
1064 |
+
f.write("\n".join(new_lines))
|
1065 |
+
|
1066 |
+
|
1067 |
+
AUTO_CLASSES_PATTERNS = {
|
1068 |
+
"configuration_auto.py": [
|
1069 |
+
' ("{model_type}", "{model_name}"),',
|
1070 |
+
' ("{model_type}", "{config_class}"),',
|
1071 |
+
' ("{model_type}", "{pretrained_archive_map}"),',
|
1072 |
+
],
|
1073 |
+
"feature_extraction_auto.py": [' ("{model_type}", "{feature_extractor_class}"),'],
|
1074 |
+
"image_processing_auto.py": [' ("{model_type}", "{image_processor_class}"),'],
|
1075 |
+
"modeling_auto.py": [' ("{model_type}", "{any_pt_class}"),'],
|
1076 |
+
"modeling_tf_auto.py": [' ("{model_type}", "{any_tf_class}"),'],
|
1077 |
+
"modeling_flax_auto.py": [' ("{model_type}", "{any_flax_class}"),'],
|
1078 |
+
"processing_auto.py": [' ("{model_type}", "{processor_class}"),'],
|
1079 |
+
}
|
1080 |
+
|
1081 |
+
|
1082 |
+
def add_model_to_auto_classes(
|
1083 |
+
old_model_patterns: ModelPatterns, new_model_patterns: ModelPatterns, model_classes: Dict[str, List[str]]
|
1084 |
+
):
|
1085 |
+
"""
|
1086 |
+
Add a model to the relevant mappings in the auto module.
|
1087 |
+
|
1088 |
+
Args:
|
1089 |
+
old_model_patterns (`ModelPatterns`): The patterns for the old model.
|
1090 |
+
new_model_patterns (`ModelPatterns`): The patterns for the new model.
|
1091 |
+
model_classes (`Dict[str, List[str]]`): A dictionary framework to list of model classes implemented.
|
1092 |
+
"""
|
1093 |
+
for filename in AUTO_CLASSES_PATTERNS:
|
1094 |
+
# Extend patterns with all model classes if necessary
|
1095 |
+
new_patterns = []
|
1096 |
+
for pattern in AUTO_CLASSES_PATTERNS[filename]:
|
1097 |
+
if re.search("any_([a-z]*)_class", pattern) is not None:
|
1098 |
+
framework = re.search("any_([a-z]*)_class", pattern).groups()[0]
|
1099 |
+
if framework in model_classes:
|
1100 |
+
new_patterns.extend(
|
1101 |
+
[
|
1102 |
+
pattern.replace("{" + f"any_{framework}_class" + "}", cls)
|
1103 |
+
for cls in model_classes[framework]
|
1104 |
+
]
|
1105 |
+
)
|
1106 |
+
elif "{config_class}" in pattern:
|
1107 |
+
new_patterns.append(pattern.replace("{config_class}", old_model_patterns.config_class))
|
1108 |
+
elif "{image_processor_class}" in pattern:
|
1109 |
+
if (
|
1110 |
+
old_model_patterns.image_processor_class is not None
|
1111 |
+
and new_model_patterns.image_processor_class is not None
|
1112 |
+
):
|
1113 |
+
new_patterns.append(
|
1114 |
+
pattern.replace("{image_processor_class}", old_model_patterns.image_processor_class)
|
1115 |
+
)
|
1116 |
+
elif "{feature_extractor_class}" in pattern:
|
1117 |
+
if (
|
1118 |
+
old_model_patterns.feature_extractor_class is not None
|
1119 |
+
and new_model_patterns.feature_extractor_class is not None
|
1120 |
+
):
|
1121 |
+
new_patterns.append(
|
1122 |
+
pattern.replace("{feature_extractor_class}", old_model_patterns.feature_extractor_class)
|
1123 |
+
)
|
1124 |
+
elif "{processor_class}" in pattern:
|
1125 |
+
if old_model_patterns.processor_class is not None and new_model_patterns.processor_class is not None:
|
1126 |
+
new_patterns.append(pattern.replace("{processor_class}", old_model_patterns.processor_class))
|
1127 |
+
else:
|
1128 |
+
new_patterns.append(pattern)
|
1129 |
+
|
1130 |
+
# Loop through all patterns.
|
1131 |
+
for pattern in new_patterns:
|
1132 |
+
full_name = TRANSFORMERS_PATH / "models" / "auto" / filename
|
1133 |
+
old_model_line = pattern
|
1134 |
+
new_model_line = pattern
|
1135 |
+
for attr in ["model_type", "model_name"]:
|
1136 |
+
old_model_line = old_model_line.replace("{" + attr + "}", getattr(old_model_patterns, attr))
|
1137 |
+
new_model_line = new_model_line.replace("{" + attr + "}", getattr(new_model_patterns, attr))
|
1138 |
+
if "pretrained_archive_map" in pattern:
|
1139 |
+
old_model_line = old_model_line.replace(
|
1140 |
+
"{pretrained_archive_map}", f"{old_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP"
|
1141 |
+
)
|
1142 |
+
new_model_line = new_model_line.replace(
|
1143 |
+
"{pretrained_archive_map}", f"{new_model_patterns.model_upper_cased}_PRETRAINED_CONFIG_ARCHIVE_MAP"
|
1144 |
+
)
|
1145 |
+
|
1146 |
+
new_model_line = new_model_line.replace(
|
1147 |
+
old_model_patterns.model_camel_cased, new_model_patterns.model_camel_cased
|
1148 |
+
)
|
1149 |
+
|
1150 |
+
add_content_to_file(full_name, new_model_line, add_after=old_model_line)
|
1151 |
+
|
1152 |
+
# Tokenizers require special handling
|
1153 |
+
insert_tokenizer_in_auto_module(old_model_patterns, new_model_patterns)
|
1154 |
+
|
1155 |
+
|
1156 |
+
DOC_OVERVIEW_TEMPLATE = """## Overview
|
1157 |
+
|
1158 |
+
The {model_name} model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
|
1159 |
+
<INSERT SHORT SUMMARY HERE>
|
1160 |
+
|
1161 |
+
The abstract from the paper is the following:
|
1162 |
+
|
1163 |
+
*<INSERT PAPER ABSTRACT HERE>*
|
1164 |
+
|
1165 |
+
Tips:
|
1166 |
+
|
1167 |
+
<INSERT TIPS ABOUT MODEL HERE>
|
1168 |
+
|
1169 |
+
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
|
1170 |
+
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
|
1171 |
+
|
1172 |
+
"""
|
1173 |
+
|
1174 |
+
|
1175 |
+
def duplicate_doc_file(
|
1176 |
+
doc_file: Union[str, os.PathLike],
|
1177 |
+
old_model_patterns: ModelPatterns,
|
1178 |
+
new_model_patterns: ModelPatterns,
|
1179 |
+
dest_file: Optional[Union[str, os.PathLike]] = None,
|
1180 |
+
frameworks: Optional[List[str]] = None,
|
1181 |
+
):
|
1182 |
+
"""
|
1183 |
+
Duplicate a documentation file and adapts it for a new model.
|
1184 |
+
|
1185 |
+
Args:
|
1186 |
+
module_file (`str` or `os.PathLike`): Path to the doc file to duplicate.
|
1187 |
+
old_model_patterns (`ModelPatterns`): The patterns for the old model.
|
1188 |
+
new_model_patterns (`ModelPatterns`): The patterns for the new model.
|
1189 |
+
dest_file (`str` or `os.PathLike`, *optional*): Path to the new doc file.
|
1190 |
+
Will default to the a file named `{new_model_patterns.model_type}.md` in the same folder as `module_file`.
|
1191 |
+
frameworks (`List[str]`, *optional*):
|
1192 |
+
If passed, will only keep the model classes corresponding to this list of frameworks in the new doc file.
|
1193 |
+
"""
|
1194 |
+
with open(doc_file, "r", encoding="utf-8") as f:
|
1195 |
+
content = f.read()
|
1196 |
+
|
1197 |
+
content = re.sub(r"<!--\s*Copyright (\d+)\s", f"<!--Copyright {CURRENT_YEAR} ", content)
|
1198 |
+
if frameworks is None:
|
1199 |
+
frameworks = get_default_frameworks()
|
1200 |
+
if dest_file is None:
|
1201 |
+
dest_file = Path(doc_file).parent / f"{new_model_patterns.model_type}.md"
|
1202 |
+
|
1203 |
+
# Parse the doc file in blocks. One block per section/header
|
1204 |
+
lines = content.split("\n")
|
1205 |
+
blocks = []
|
1206 |
+
current_block = []
|
1207 |
+
|
1208 |
+
for line in lines:
|
1209 |
+
if line.startswith("#"):
|
1210 |
+
blocks.append("\n".join(current_block))
|
1211 |
+
current_block = [line]
|
1212 |
+
else:
|
1213 |
+
current_block.append(line)
|
1214 |
+
blocks.append("\n".join(current_block))
|
1215 |
+
|
1216 |
+
new_blocks = []
|
1217 |
+
in_classes = False
|
1218 |
+
for block in blocks:
|
1219 |
+
# Copyright
|
1220 |
+
if not block.startswith("#"):
|
1221 |
+
new_blocks.append(block)
|
1222 |
+
# Main title
|
1223 |
+
elif re.search(r"^#\s+\S+", block) is not None:
|
1224 |
+
new_blocks.append(f"# {new_model_patterns.model_name}\n")
|
1225 |
+
# The config starts the part of the doc with the classes.
|
1226 |
+
elif not in_classes and old_model_patterns.config_class in block.split("\n")[0]:
|
1227 |
+
in_classes = True
|
1228 |
+
new_blocks.append(DOC_OVERVIEW_TEMPLATE.format(model_name=new_model_patterns.model_name))
|
1229 |
+
new_block, _ = replace_model_patterns(block, old_model_patterns, new_model_patterns)
|
1230 |
+
new_blocks.append(new_block)
|
1231 |
+
# In classes
|
1232 |
+
elif in_classes:
|
1233 |
+
in_classes = True
|
1234 |
+
block_title = block.split("\n")[0]
|
1235 |
+
block_class = re.search(r"^#+\s+(\S.*)$", block_title).groups()[0]
|
1236 |
+
new_block, _ = replace_model_patterns(block, old_model_patterns, new_model_patterns)
|
1237 |
+
|
1238 |
+
if "Tokenizer" in block_class:
|
1239 |
+
# We only add the tokenizer if necessary
|
1240 |
+
if old_model_patterns.tokenizer_class != new_model_patterns.tokenizer_class:
|
1241 |
+
new_blocks.append(new_block)
|
1242 |
+
elif "ImageProcessor" in block_class:
|
1243 |
+
# We only add the image processor if necessary
|
1244 |
+
if old_model_patterns.image_processor_class != new_model_patterns.image_processor_class:
|
1245 |
+
new_blocks.append(new_block)
|
1246 |
+
elif "FeatureExtractor" in block_class:
|
1247 |
+
# We only add the feature extractor if necessary
|
1248 |
+
if old_model_patterns.feature_extractor_class != new_model_patterns.feature_extractor_class:
|
1249 |
+
new_blocks.append(new_block)
|
1250 |
+
elif "Processor" in block_class:
|
1251 |
+
# We only add the processor if necessary
|
1252 |
+
if old_model_patterns.processor_class != new_model_patterns.processor_class:
|
1253 |
+
new_blocks.append(new_block)
|
1254 |
+
elif block_class.startswith("Flax"):
|
1255 |
+
# We only add Flax models if in the selected frameworks
|
1256 |
+
if "flax" in frameworks:
|
1257 |
+
new_blocks.append(new_block)
|
1258 |
+
elif block_class.startswith("TF"):
|
1259 |
+
# We only add TF models if in the selected frameworks
|
1260 |
+
if "tf" in frameworks:
|
1261 |
+
new_blocks.append(new_block)
|
1262 |
+
elif len(block_class.split(" ")) == 1:
|
1263 |
+
# We only add PyTorch models if in the selected frameworks
|
1264 |
+
if "pt" in frameworks:
|
1265 |
+
new_blocks.append(new_block)
|
1266 |
+
else:
|
1267 |
+
new_blocks.append(new_block)
|
1268 |
+
|
1269 |
+
with open(dest_file, "w", encoding="utf-8") as f:
|
1270 |
+
f.write("\n".join(new_blocks))
|
1271 |
+
|
1272 |
+
|
1273 |
+
def insert_model_in_doc_toc(old_model_patterns, new_model_patterns):
|
1274 |
+
"""
|
1275 |
+
Insert the new model in the doc TOC, in the same section as the old model.
|
1276 |
+
|
1277 |
+
Args:
|
1278 |
+
old_model_patterns (`ModelPatterns`): The patterns for the old model.
|
1279 |
+
new_model_patterns (`ModelPatterns`): The patterns for the new model.
|
1280 |
+
"""
|
1281 |
+
toc_file = REPO_PATH / "docs" / "source" / "en" / "_toctree.yml"
|
1282 |
+
with open(toc_file, "r", encoding="utf8") as f:
|
1283 |
+
content = yaml.safe_load(f)
|
1284 |
+
|
1285 |
+
# Get to the model API doc
|
1286 |
+
api_idx = 0
|
1287 |
+
while content[api_idx]["title"] != "API":
|
1288 |
+
api_idx += 1
|
1289 |
+
api_doc = content[api_idx]["sections"]
|
1290 |
+
|
1291 |
+
model_idx = 0
|
1292 |
+
while api_doc[model_idx]["title"] != "Models":
|
1293 |
+
model_idx += 1
|
1294 |
+
model_doc = api_doc[model_idx]["sections"]
|
1295 |
+
|
1296 |
+
# Find the base model in the Toc
|
1297 |
+
old_model_type = old_model_patterns.model_type
|
1298 |
+
section_idx = 0
|
1299 |
+
while section_idx < len(model_doc):
|
1300 |
+
sections = [entry["local"] for entry in model_doc[section_idx]["sections"]]
|
1301 |
+
if f"model_doc/{old_model_type}" in sections:
|
1302 |
+
break
|
1303 |
+
|
1304 |
+
section_idx += 1
|
1305 |
+
|
1306 |
+
if section_idx == len(model_doc):
|
1307 |
+
old_model = old_model_patterns.model_name
|
1308 |
+
new_model = new_model_patterns.model_name
|
1309 |
+
print(f"Did not find {old_model} in the table of content, so you will need to add {new_model} manually.")
|
1310 |
+
return
|
1311 |
+
|
1312 |
+
# Add the new model in the same toc
|
1313 |
+
toc_entry = {"local": f"model_doc/{new_model_patterns.model_type}", "title": new_model_patterns.model_name}
|
1314 |
+
model_doc[section_idx]["sections"].append(toc_entry)
|
1315 |
+
model_doc[section_idx]["sections"] = sorted(model_doc[section_idx]["sections"], key=lambda s: s["title"].lower())
|
1316 |
+
api_doc[model_idx]["sections"] = model_doc
|
1317 |
+
content[api_idx]["sections"] = api_doc
|
1318 |
+
|
1319 |
+
with open(toc_file, "w", encoding="utf-8") as f:
|
1320 |
+
f.write(yaml.dump(content, allow_unicode=True))
|
1321 |
+
|
1322 |
+
|
1323 |
+
def create_new_model_like(
|
1324 |
+
model_type: str,
|
1325 |
+
new_model_patterns: ModelPatterns,
|
1326 |
+
add_copied_from: bool = True,
|
1327 |
+
frameworks: Optional[List[str]] = None,
|
1328 |
+
old_checkpoint: Optional[str] = None,
|
1329 |
+
):
|
1330 |
+
"""
|
1331 |
+
Creates a new model module like a given model of the Transformers library.
|
1332 |
+
|
1333 |
+
Args:
|
1334 |
+
model_type (`str`): The model type to duplicate (like "bert" or "gpt2")
|
1335 |
+
new_model_patterns (`ModelPatterns`): The patterns for the new model.
|
1336 |
+
add_copied_from (`bool`, *optional*, defaults to `True`):
|
1337 |
+
Whether or not to add "Copied from" statements to all classes in the new model modeling files.
|
1338 |
+
frameworks (`List[str]`, *optional*):
|
1339 |
+
If passed, will limit the duplicate to the frameworks specified.
|
1340 |
+
old_checkpoint (`str`, *optional*):
|
1341 |
+
The name of the base checkpoint for the old model. Should be passed along when it can't be automatically
|
1342 |
+
recovered from the `model_type`.
|
1343 |
+
"""
|
1344 |
+
# Retrieve all the old model info.
|
1345 |
+
model_info = retrieve_info_for_model(model_type, frameworks=frameworks)
|
1346 |
+
model_files = model_info["model_files"]
|
1347 |
+
old_model_patterns = model_info["model_patterns"]
|
1348 |
+
if old_checkpoint is not None:
|
1349 |
+
old_model_patterns.checkpoint = old_checkpoint
|
1350 |
+
if len(old_model_patterns.checkpoint) == 0:
|
1351 |
+
raise ValueError(
|
1352 |
+
"The old model checkpoint could not be recovered from the model type. Please pass it to the "
|
1353 |
+
"`old_checkpoint` argument."
|
1354 |
+
)
|
1355 |
+
|
1356 |
+
keep_old_processing = True
|
1357 |
+
for processing_attr in ["image_processor_class", "feature_extractor_class", "processor_class", "tokenizer_class"]:
|
1358 |
+
if getattr(old_model_patterns, processing_attr) != getattr(new_model_patterns, processing_attr):
|
1359 |
+
keep_old_processing = False
|
1360 |
+
|
1361 |
+
model_classes = model_info["model_classes"]
|
1362 |
+
|
1363 |
+
# 1. We create the module for our new model.
|
1364 |
+
old_module_name = model_files["module_name"]
|
1365 |
+
module_folder = TRANSFORMERS_PATH / "models" / new_model_patterns.model_lower_cased
|
1366 |
+
os.makedirs(module_folder, exist_ok=True)
|
1367 |
+
|
1368 |
+
files_to_adapt = model_files["model_files"]
|
1369 |
+
if keep_old_processing:
|
1370 |
+
files_to_adapt = [
|
1371 |
+
f
|
1372 |
+
for f in files_to_adapt
|
1373 |
+
if "tokenization" not in str(f)
|
1374 |
+
and "processing" not in str(f)
|
1375 |
+
and "feature_extraction" not in str(f)
|
1376 |
+
and "image_processing" not in str(f)
|
1377 |
+
]
|
1378 |
+
|
1379 |
+
os.makedirs(module_folder, exist_ok=True)
|
1380 |
+
for module_file in files_to_adapt:
|
1381 |
+
new_module_name = module_file.name.replace(
|
1382 |
+
old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
|
1383 |
+
)
|
1384 |
+
dest_file = module_folder / new_module_name
|
1385 |
+
duplicate_module(
|
1386 |
+
module_file,
|
1387 |
+
old_model_patterns,
|
1388 |
+
new_model_patterns,
|
1389 |
+
dest_file=dest_file,
|
1390 |
+
add_copied_from=add_copied_from and "modeling" in new_module_name,
|
1391 |
+
)
|
1392 |
+
|
1393 |
+
clean_frameworks_in_init(
|
1394 |
+
module_folder / "__init__.py", frameworks=frameworks, keep_processing=not keep_old_processing
|
1395 |
+
)
|
1396 |
+
|
1397 |
+
# 2. We add our new model to the models init and the main init
|
1398 |
+
add_content_to_file(
|
1399 |
+
TRANSFORMERS_PATH / "models" / "__init__.py",
|
1400 |
+
f" {new_model_patterns.model_lower_cased},",
|
1401 |
+
add_after=f" {old_module_name},",
|
1402 |
+
exact_match=True,
|
1403 |
+
)
|
1404 |
+
add_model_to_main_init(
|
1405 |
+
old_model_patterns, new_model_patterns, frameworks=frameworks, with_processing=not keep_old_processing
|
1406 |
+
)
|
1407 |
+
|
1408 |
+
# 3. Add test files
|
1409 |
+
files_to_adapt = model_files["test_files"]
|
1410 |
+
if keep_old_processing:
|
1411 |
+
files_to_adapt = [
|
1412 |
+
f
|
1413 |
+
for f in files_to_adapt
|
1414 |
+
if "tokenization" not in str(f)
|
1415 |
+
and "processor" not in str(f)
|
1416 |
+
and "feature_extraction" not in str(f)
|
1417 |
+
and "image_processing" not in str(f)
|
1418 |
+
]
|
1419 |
+
|
1420 |
+
def disable_fx_test(filename: Path) -> bool:
|
1421 |
+
with open(filename) as fp:
|
1422 |
+
content = fp.read()
|
1423 |
+
new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content)
|
1424 |
+
with open(filename, "w") as fp:
|
1425 |
+
fp.write(new_content)
|
1426 |
+
return content != new_content
|
1427 |
+
|
1428 |
+
disabled_fx_test = False
|
1429 |
+
|
1430 |
+
tests_folder = REPO_PATH / "tests" / "models" / new_model_patterns.model_lower_cased
|
1431 |
+
os.makedirs(tests_folder, exist_ok=True)
|
1432 |
+
with open(tests_folder / "__init__.py", "w"):
|
1433 |
+
pass
|
1434 |
+
|
1435 |
+
for test_file in files_to_adapt:
|
1436 |
+
new_test_file_name = test_file.name.replace(
|
1437 |
+
old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
|
1438 |
+
)
|
1439 |
+
dest_file = test_file.parent.parent / new_model_patterns.model_lower_cased / new_test_file_name
|
1440 |
+
duplicate_module(
|
1441 |
+
test_file,
|
1442 |
+
old_model_patterns,
|
1443 |
+
new_model_patterns,
|
1444 |
+
dest_file=dest_file,
|
1445 |
+
add_copied_from=False,
|
1446 |
+
attrs_to_remove=["pipeline_model_mapping", "is_pipeline_test_to_skip"],
|
1447 |
+
)
|
1448 |
+
disabled_fx_test = disabled_fx_test | disable_fx_test(dest_file)
|
1449 |
+
|
1450 |
+
if disabled_fx_test:
|
1451 |
+
print(
|
1452 |
+
"The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works"
|
1453 |
+
" for your new model."
|
1454 |
+
)
|
1455 |
+
|
1456 |
+
# 4. Add model to auto classes
|
1457 |
+
add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes)
|
1458 |
+
|
1459 |
+
# 5. Add doc file
|
1460 |
+
doc_file = REPO_PATH / "docs" / "source" / "en" / "model_doc" / f"{old_model_patterns.model_type}.md"
|
1461 |
+
duplicate_doc_file(doc_file, old_model_patterns, new_model_patterns, frameworks=frameworks)
|
1462 |
+
insert_model_in_doc_toc(old_model_patterns, new_model_patterns)
|
1463 |
+
|
1464 |
+
# 6. Warn the user for duplicate patterns
|
1465 |
+
if old_model_patterns.model_type == old_model_patterns.checkpoint:
|
1466 |
+
print(
|
1467 |
+
"The model you picked has the same name for the model type and the checkpoint name "
|
1468 |
+
f"({old_model_patterns.model_type}). As a result, it's possible some places where the new checkpoint "
|
1469 |
+
f"should be, you have {new_model_patterns.model_type} instead. You should search for all instances of "
|
1470 |
+
f"{new_model_patterns.model_type} in the new files and check they're not badly used as checkpoints."
|
1471 |
+
)
|
1472 |
+
elif old_model_patterns.model_lower_cased == old_model_patterns.checkpoint:
|
1473 |
+
print(
|
1474 |
+
"The model you picked has the same name for the model type and the checkpoint name "
|
1475 |
+
f"({old_model_patterns.model_lower_cased}). As a result, it's possible some places where the new "
|
1476 |
+
f"checkpoint should be, you have {new_model_patterns.model_lower_cased} instead. You should search for "
|
1477 |
+
f"all instances of {new_model_patterns.model_lower_cased} in the new files and check they're not badly "
|
1478 |
+
"used as checkpoints."
|
1479 |
+
)
|
1480 |
+
if (
|
1481 |
+
old_model_patterns.model_type == old_model_patterns.model_lower_cased
|
1482 |
+
and new_model_patterns.model_type != new_model_patterns.model_lower_cased
|
1483 |
+
):
|
1484 |
+
print(
|
1485 |
+
"The model you picked has the same name for the model type and the lowercased model name "
|
1486 |
+
f"({old_model_patterns.model_lower_cased}). As a result, it's possible some places where the new "
|
1487 |
+
f"model type should be, you have {new_model_patterns.model_lower_cased} instead. You should search for "
|
1488 |
+
f"all instances of {new_model_patterns.model_lower_cased} in the new files and check they're not badly "
|
1489 |
+
"used as the model type."
|
1490 |
+
)
|
1491 |
+
|
1492 |
+
if not keep_old_processing and old_model_patterns.tokenizer_class is not None:
|
1493 |
+
print(
|
1494 |
+
"The constants at the start of the new tokenizer file created needs to be manually fixed. If your new "
|
1495 |
+
"model has a tokenizer fast, you will also need to manually add the converter in the "
|
1496 |
+
"`SLOW_TO_FAST_CONVERTERS` constant of `convert_slow_tokenizer.py`."
|
1497 |
+
)
|
1498 |
+
|
1499 |
+
|
1500 |
+
def add_new_model_like_command_factory(args: Namespace):
|
1501 |
+
return AddNewModelLikeCommand(config_file=args.config_file, path_to_repo=args.path_to_repo)
|
1502 |
+
|
1503 |
+
|
1504 |
+
class AddNewModelLikeCommand(BaseTransformersCLICommand):
|
1505 |
+
@staticmethod
|
1506 |
+
def register_subcommand(parser: ArgumentParser):
|
1507 |
+
add_new_model_like_parser = parser.add_parser("add-new-model-like")
|
1508 |
+
add_new_model_like_parser.add_argument(
|
1509 |
+
"--config_file", type=str, help="A file with all the information for this model creation."
|
1510 |
+
)
|
1511 |
+
add_new_model_like_parser.add_argument(
|
1512 |
+
"--path_to_repo", type=str, help="When not using an editable install, the path to the Transformers repo."
|
1513 |
+
)
|
1514 |
+
add_new_model_like_parser.set_defaults(func=add_new_model_like_command_factory)
|
1515 |
+
|
1516 |
+
def __init__(self, config_file=None, path_to_repo=None, *args):
|
1517 |
+
if config_file is not None:
|
1518 |
+
with open(config_file, "r", encoding="utf-8") as f:
|
1519 |
+
config = json.load(f)
|
1520 |
+
self.old_model_type = config["old_model_type"]
|
1521 |
+
self.model_patterns = ModelPatterns(**config["new_model_patterns"])
|
1522 |
+
self.add_copied_from = config.get("add_copied_from", True)
|
1523 |
+
self.frameworks = config.get("frameworks", get_default_frameworks())
|
1524 |
+
self.old_checkpoint = config.get("old_checkpoint", None)
|
1525 |
+
else:
|
1526 |
+
(
|
1527 |
+
self.old_model_type,
|
1528 |
+
self.model_patterns,
|
1529 |
+
self.add_copied_from,
|
1530 |
+
self.frameworks,
|
1531 |
+
self.old_checkpoint,
|
1532 |
+
) = get_user_input()
|
1533 |
+
|
1534 |
+
self.path_to_repo = path_to_repo
|
1535 |
+
|
1536 |
+
def run(self):
|
1537 |
+
if self.path_to_repo is not None:
|
1538 |
+
# Adapt constants
|
1539 |
+
global TRANSFORMERS_PATH
|
1540 |
+
global REPO_PATH
|
1541 |
+
|
1542 |
+
REPO_PATH = Path(self.path_to_repo)
|
1543 |
+
TRANSFORMERS_PATH = REPO_PATH / "src" / "transformers"
|
1544 |
+
|
1545 |
+
create_new_model_like(
|
1546 |
+
model_type=self.old_model_type,
|
1547 |
+
new_model_patterns=self.model_patterns,
|
1548 |
+
add_copied_from=self.add_copied_from,
|
1549 |
+
frameworks=self.frameworks,
|
1550 |
+
old_checkpoint=self.old_checkpoint,
|
1551 |
+
)
|
1552 |
+
|
1553 |
+
|
1554 |
+
def get_user_field(
|
1555 |
+
question: str,
|
1556 |
+
default_value: Optional[str] = None,
|
1557 |
+
is_valid_answer: Optional[Callable] = None,
|
1558 |
+
convert_to: Optional[Callable] = None,
|
1559 |
+
fallback_message: Optional[str] = None,
|
1560 |
+
) -> Any:
|
1561 |
+
"""
|
1562 |
+
A utility function that asks a question to the user to get an answer, potentially looping until it gets a valid
|
1563 |
+
answer.
|
1564 |
+
|
1565 |
+
Args:
|
1566 |
+
question (`str`): The question to ask the user.
|
1567 |
+
default_value (`str`, *optional*): A potential default value that will be used when the answer is empty.
|
1568 |
+
is_valid_answer (`Callable`, *optional*):
|
1569 |
+
If set, the question will be asked until this function returns `True` on the provided answer.
|
1570 |
+
convert_to (`Callable`, *optional*):
|
1571 |
+
If set, the answer will be passed to this function. If this function raises an error on the procided
|
1572 |
+
answer, the question will be asked again.
|
1573 |
+
fallback_message (`str`, *optional*):
|
1574 |
+
A message that will be displayed each time the question is asked again to the user.
|
1575 |
+
|
1576 |
+
Returns:
|
1577 |
+
`Any`: The answer provided by the user (or the default), passed through the potential conversion function.
|
1578 |
+
"""
|
1579 |
+
if not question.endswith(" "):
|
1580 |
+
question = question + " "
|
1581 |
+
if default_value is not None:
|
1582 |
+
question = f"{question} [{default_value}] "
|
1583 |
+
|
1584 |
+
valid_answer = False
|
1585 |
+
while not valid_answer:
|
1586 |
+
answer = input(question)
|
1587 |
+
if default_value is not None and len(answer) == 0:
|
1588 |
+
answer = default_value
|
1589 |
+
if is_valid_answer is not None:
|
1590 |
+
valid_answer = is_valid_answer(answer)
|
1591 |
+
elif convert_to is not None:
|
1592 |
+
try:
|
1593 |
+
answer = convert_to(answer)
|
1594 |
+
valid_answer = True
|
1595 |
+
except Exception:
|
1596 |
+
valid_answer = False
|
1597 |
+
else:
|
1598 |
+
valid_answer = True
|
1599 |
+
|
1600 |
+
if not valid_answer:
|
1601 |
+
print(fallback_message)
|
1602 |
+
|
1603 |
+
return answer
|
1604 |
+
|
1605 |
+
|
1606 |
+
def convert_to_bool(x: str) -> bool:
|
1607 |
+
"""
|
1608 |
+
Converts a string to a bool.
|
1609 |
+
"""
|
1610 |
+
if x.lower() in ["1", "y", "yes", "true"]:
|
1611 |
+
return True
|
1612 |
+
if x.lower() in ["0", "n", "no", "false"]:
|
1613 |
+
return False
|
1614 |
+
raise ValueError(f"{x} is not a value that can be converted to a bool.")
|
1615 |
+
|
1616 |
+
|
1617 |
+
def get_user_input():
|
1618 |
+
"""
|
1619 |
+
Ask the user for the necessary inputs to add the new model.
|
1620 |
+
"""
|
1621 |
+
model_types = list(auto_module.configuration_auto.MODEL_NAMES_MAPPING.keys())
|
1622 |
+
|
1623 |
+
# Get old model type
|
1624 |
+
valid_model_type = False
|
1625 |
+
while not valid_model_type:
|
1626 |
+
old_model_type = input(
|
1627 |
+
"What is the model you would like to duplicate? Please provide the lowercase `model_type` (e.g. roberta): "
|
1628 |
+
)
|
1629 |
+
if old_model_type in model_types:
|
1630 |
+
valid_model_type = True
|
1631 |
+
else:
|
1632 |
+
print(f"{old_model_type} is not a valid model type.")
|
1633 |
+
near_choices = difflib.get_close_matches(old_model_type, model_types)
|
1634 |
+
if len(near_choices) >= 1:
|
1635 |
+
if len(near_choices) > 1:
|
1636 |
+
near_choices = " or ".join(near_choices)
|
1637 |
+
print(f"Did you mean {near_choices}?")
|
1638 |
+
|
1639 |
+
old_model_info = retrieve_info_for_model(old_model_type)
|
1640 |
+
old_tokenizer_class = old_model_info["model_patterns"].tokenizer_class
|
1641 |
+
old_image_processor_class = old_model_info["model_patterns"].image_processor_class
|
1642 |
+
old_feature_extractor_class = old_model_info["model_patterns"].feature_extractor_class
|
1643 |
+
old_processor_class = old_model_info["model_patterns"].processor_class
|
1644 |
+
old_frameworks = old_model_info["frameworks"]
|
1645 |
+
|
1646 |
+
old_checkpoint = None
|
1647 |
+
if len(old_model_info["model_patterns"].checkpoint) == 0:
|
1648 |
+
old_checkpoint = get_user_field(
|
1649 |
+
"We couldn't find the name of the base checkpoint for that model, please enter it here."
|
1650 |
+
)
|
1651 |
+
|
1652 |
+
model_name = get_user_field(
|
1653 |
+
"What is the name (with no special casing) for your new model in the paper (e.g. RoBERTa)? "
|
1654 |
+
)
|
1655 |
+
default_patterns = ModelPatterns(model_name, model_name)
|
1656 |
+
|
1657 |
+
model_type = get_user_field(
|
1658 |
+
"What identifier would you like to use for the `model_type` of this model? ",
|
1659 |
+
default_value=default_patterns.model_type,
|
1660 |
+
)
|
1661 |
+
model_lower_cased = get_user_field(
|
1662 |
+
"What lowercase name would you like to use for the module (folder) of this model? ",
|
1663 |
+
default_value=default_patterns.model_lower_cased,
|
1664 |
+
)
|
1665 |
+
model_camel_cased = get_user_field(
|
1666 |
+
"What prefix (camel-cased) would you like to use for the model classes of this model (e.g. Roberta)? ",
|
1667 |
+
default_value=default_patterns.model_camel_cased,
|
1668 |
+
)
|
1669 |
+
model_upper_cased = get_user_field(
|
1670 |
+
"What prefix (upper-cased) would you like to use for the constants relative to this model? ",
|
1671 |
+
default_value=default_patterns.model_upper_cased,
|
1672 |
+
)
|
1673 |
+
config_class = get_user_field(
|
1674 |
+
"What will be the name of the config class for this model? ", default_value=f"{model_camel_cased}Config"
|
1675 |
+
)
|
1676 |
+
checkpoint = get_user_field(
|
1677 |
+
"Please give a checkpoint identifier (on the model Hub) for this new model (e.g. facebook/roberta-base): "
|
1678 |
+
)
|
1679 |
+
|
1680 |
+
old_processing_classes = [
|
1681 |
+
c
|
1682 |
+
for c in [old_image_processor_class, old_feature_extractor_class, old_tokenizer_class, old_processor_class]
|
1683 |
+
if c is not None
|
1684 |
+
]
|
1685 |
+
old_processing_classes = ", ".join(old_processing_classes)
|
1686 |
+
keep_processing = get_user_field(
|
1687 |
+
f"Will your new model use the same processing class as {old_model_type} ({old_processing_classes}) (yes/no)? ",
|
1688 |
+
convert_to=convert_to_bool,
|
1689 |
+
fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
|
1690 |
+
)
|
1691 |
+
if keep_processing:
|
1692 |
+
image_processor_class = old_image_processor_class
|
1693 |
+
feature_extractor_class = old_feature_extractor_class
|
1694 |
+
processor_class = old_processor_class
|
1695 |
+
tokenizer_class = old_tokenizer_class
|
1696 |
+
else:
|
1697 |
+
if old_tokenizer_class is not None:
|
1698 |
+
tokenizer_class = get_user_field(
|
1699 |
+
"What will be the name of the tokenizer class for this model? ",
|
1700 |
+
default_value=f"{model_camel_cased}Tokenizer",
|
1701 |
+
)
|
1702 |
+
else:
|
1703 |
+
tokenizer_class = None
|
1704 |
+
if old_image_processor_class is not None:
|
1705 |
+
image_processor_class = get_user_field(
|
1706 |
+
"What will be the name of the image processor class for this model? ",
|
1707 |
+
default_value=f"{model_camel_cased}ImageProcessor",
|
1708 |
+
)
|
1709 |
+
else:
|
1710 |
+
image_processor_class = None
|
1711 |
+
if old_feature_extractor_class is not None:
|
1712 |
+
feature_extractor_class = get_user_field(
|
1713 |
+
"What will be the name of the feature extractor class for this model? ",
|
1714 |
+
default_value=f"{model_camel_cased}FeatureExtractor",
|
1715 |
+
)
|
1716 |
+
else:
|
1717 |
+
feature_extractor_class = None
|
1718 |
+
if old_processor_class is not None:
|
1719 |
+
processor_class = get_user_field(
|
1720 |
+
"What will be the name of the processor class for this model? ",
|
1721 |
+
default_value=f"{model_camel_cased}Processor",
|
1722 |
+
)
|
1723 |
+
else:
|
1724 |
+
processor_class = None
|
1725 |
+
|
1726 |
+
model_patterns = ModelPatterns(
|
1727 |
+
model_name,
|
1728 |
+
checkpoint,
|
1729 |
+
model_type=model_type,
|
1730 |
+
model_lower_cased=model_lower_cased,
|
1731 |
+
model_camel_cased=model_camel_cased,
|
1732 |
+
model_upper_cased=model_upper_cased,
|
1733 |
+
config_class=config_class,
|
1734 |
+
tokenizer_class=tokenizer_class,
|
1735 |
+
image_processor_class=image_processor_class,
|
1736 |
+
feature_extractor_class=feature_extractor_class,
|
1737 |
+
processor_class=processor_class,
|
1738 |
+
)
|
1739 |
+
|
1740 |
+
add_copied_from = get_user_field(
|
1741 |
+
"Should we add # Copied from statements when creating the new modeling file (yes/no)? ",
|
1742 |
+
convert_to=convert_to_bool,
|
1743 |
+
default_value="yes",
|
1744 |
+
fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
|
1745 |
+
)
|
1746 |
+
|
1747 |
+
all_frameworks = get_user_field(
|
1748 |
+
"Should we add a version of your new model in all the frameworks implemented by"
|
1749 |
+
f" {old_model_type} ({old_frameworks}) (yes/no)? ",
|
1750 |
+
convert_to=convert_to_bool,
|
1751 |
+
default_value="yes",
|
1752 |
+
fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
|
1753 |
+
)
|
1754 |
+
if all_frameworks:
|
1755 |
+
frameworks = None
|
1756 |
+
else:
|
1757 |
+
frameworks = get_user_field(
|
1758 |
+
"Please enter the list of framworks you want (pt, tf, flax) separated by spaces",
|
1759 |
+
is_valid_answer=lambda x: all(p in ["pt", "tf", "flax"] for p in x.split(" ")),
|
1760 |
+
)
|
1761 |
+
frameworks = list(set(frameworks.split(" ")))
|
1762 |
+
|
1763 |
+
return (old_model_type, model_patterns, add_copied_from, frameworks, old_checkpoint)
|
transformers_4_35_0/commands/convert.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from argparse import ArgumentParser, Namespace
|
16 |
+
|
17 |
+
from ..utils import logging
|
18 |
+
from . import BaseTransformersCLICommand
|
19 |
+
|
20 |
+
|
21 |
+
def convert_command_factory(args: Namespace):
|
22 |
+
"""
|
23 |
+
Factory function used to convert a model TF 1.0 checkpoint in a PyTorch checkpoint.
|
24 |
+
|
25 |
+
Returns: ServeCommand
|
26 |
+
"""
|
27 |
+
return ConvertCommand(
|
28 |
+
args.model_type, args.tf_checkpoint, args.pytorch_dump_output, args.config, args.finetuning_task_name
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
IMPORT_ERROR_MESSAGE = """
|
33 |
+
transformers can only be used from the commandline to convert TensorFlow models in PyTorch, In that case, it requires
|
34 |
+
TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.
|
35 |
+
"""
|
36 |
+
|
37 |
+
|
38 |
+
class ConvertCommand(BaseTransformersCLICommand):
|
39 |
+
@staticmethod
|
40 |
+
def register_subcommand(parser: ArgumentParser):
|
41 |
+
"""
|
42 |
+
Register this command to argparse so it's available for the transformer-cli
|
43 |
+
|
44 |
+
Args:
|
45 |
+
parser: Root parser to register command-specific arguments
|
46 |
+
"""
|
47 |
+
train_parser = parser.add_parser(
|
48 |
+
"convert",
|
49 |
+
help="CLI tool to run convert model from original author checkpoints to Transformers PyTorch checkpoints.",
|
50 |
+
)
|
51 |
+
train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
|
52 |
+
train_parser.add_argument(
|
53 |
+
"--tf_checkpoint", type=str, required=True, help="TensorFlow checkpoint path or folder."
|
54 |
+
)
|
55 |
+
train_parser.add_argument(
|
56 |
+
"--pytorch_dump_output", type=str, required=True, help="Path to the PyTorch saved model output."
|
57 |
+
)
|
58 |
+
train_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.")
|
59 |
+
train_parser.add_argument(
|
60 |
+
"--finetuning_task_name",
|
61 |
+
type=str,
|
62 |
+
default=None,
|
63 |
+
help="Optional fine-tuning task name if the TF model was a finetuned model.",
|
64 |
+
)
|
65 |
+
train_parser.set_defaults(func=convert_command_factory)
|
66 |
+
|
67 |
+
def __init__(
|
68 |
+
self,
|
69 |
+
model_type: str,
|
70 |
+
tf_checkpoint: str,
|
71 |
+
pytorch_dump_output: str,
|
72 |
+
config: str,
|
73 |
+
finetuning_task_name: str,
|
74 |
+
*args,
|
75 |
+
):
|
76 |
+
self._logger = logging.get_logger("transformers-cli/converting")
|
77 |
+
|
78 |
+
self._logger.info(f"Loading model {model_type}")
|
79 |
+
self._model_type = model_type
|
80 |
+
self._tf_checkpoint = tf_checkpoint
|
81 |
+
self._pytorch_dump_output = pytorch_dump_output
|
82 |
+
self._config = config
|
83 |
+
self._finetuning_task_name = finetuning_task_name
|
84 |
+
|
85 |
+
def run(self):
|
86 |
+
if self._model_type == "albert":
|
87 |
+
try:
|
88 |
+
from ..models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
|
89 |
+
convert_tf_checkpoint_to_pytorch,
|
90 |
+
)
|
91 |
+
except ImportError:
|
92 |
+
raise ImportError(IMPORT_ERROR_MESSAGE)
|
93 |
+
|
94 |
+
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
95 |
+
elif self._model_type == "bert":
|
96 |
+
try:
|
97 |
+
from ..models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
|
98 |
+
convert_tf_checkpoint_to_pytorch,
|
99 |
+
)
|
100 |
+
except ImportError:
|
101 |
+
raise ImportError(IMPORT_ERROR_MESSAGE)
|
102 |
+
|
103 |
+
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
104 |
+
elif self._model_type == "funnel":
|
105 |
+
try:
|
106 |
+
from ..models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
|
107 |
+
convert_tf_checkpoint_to_pytorch,
|
108 |
+
)
|
109 |
+
except ImportError:
|
110 |
+
raise ImportError(IMPORT_ERROR_MESSAGE)
|
111 |
+
|
112 |
+
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
113 |
+
elif self._model_type == "t5":
|
114 |
+
try:
|
115 |
+
from ..models.t5.convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
|
116 |
+
except ImportError:
|
117 |
+
raise ImportError(IMPORT_ERROR_MESSAGE)
|
118 |
+
|
119 |
+
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
120 |
+
elif self._model_type == "gpt":
|
121 |
+
from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
|
122 |
+
convert_openai_checkpoint_to_pytorch,
|
123 |
+
)
|
124 |
+
|
125 |
+
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
126 |
+
elif self._model_type == "transfo_xl":
|
127 |
+
try:
|
128 |
+
from ..models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
|
129 |
+
convert_transfo_xl_checkpoint_to_pytorch,
|
130 |
+
)
|
131 |
+
except ImportError:
|
132 |
+
raise ImportError(IMPORT_ERROR_MESSAGE)
|
133 |
+
|
134 |
+
if "ckpt" in self._tf_checkpoint.lower():
|
135 |
+
TF_CHECKPOINT = self._tf_checkpoint
|
136 |
+
TF_DATASET_FILE = ""
|
137 |
+
else:
|
138 |
+
TF_DATASET_FILE = self._tf_checkpoint
|
139 |
+
TF_CHECKPOINT = ""
|
140 |
+
convert_transfo_xl_checkpoint_to_pytorch(
|
141 |
+
TF_CHECKPOINT, self._config, self._pytorch_dump_output, TF_DATASET_FILE
|
142 |
+
)
|
143 |
+
elif self._model_type == "gpt2":
|
144 |
+
try:
|
145 |
+
from ..models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
|
146 |
+
convert_gpt2_checkpoint_to_pytorch,
|
147 |
+
)
|
148 |
+
except ImportError:
|
149 |
+
raise ImportError(IMPORT_ERROR_MESSAGE)
|
150 |
+
|
151 |
+
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
152 |
+
elif self._model_type == "xlnet":
|
153 |
+
try:
|
154 |
+
from ..models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
|
155 |
+
convert_xlnet_checkpoint_to_pytorch,
|
156 |
+
)
|
157 |
+
except ImportError:
|
158 |
+
raise ImportError(IMPORT_ERROR_MESSAGE)
|
159 |
+
|
160 |
+
convert_xlnet_checkpoint_to_pytorch(
|
161 |
+
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
|
162 |
+
)
|
163 |
+
elif self._model_type == "xlm":
|
164 |
+
from ..models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
|
165 |
+
convert_xlm_checkpoint_to_pytorch,
|
166 |
+
)
|
167 |
+
|
168 |
+
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
|
169 |
+
elif self._model_type == "lxmert":
|
170 |
+
from ..models.lxmert.convert_lxmert_original_tf_checkpoint_to_pytorch import (
|
171 |
+
convert_lxmert_checkpoint_to_pytorch,
|
172 |
+
)
|
173 |
+
|
174 |
+
convert_lxmert_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
|
175 |
+
elif self._model_type == "rembert":
|
176 |
+
from ..models.rembert.convert_rembert_tf_checkpoint_to_pytorch import (
|
177 |
+
convert_rembert_tf_checkpoint_to_pytorch,
|
178 |
+
)
|
179 |
+
|
180 |
+
convert_rembert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
181 |
+
else:
|
182 |
+
raise ValueError(
|
183 |
+
"--model_type should be selected in the list [bert, gpt, gpt2, t5, transfo_xl, xlnet, xlm, lxmert]"
|
184 |
+
)
|
transformers_4_35_0/commands/download.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from argparse import ArgumentParser
|
16 |
+
|
17 |
+
from . import BaseTransformersCLICommand
|
18 |
+
|
19 |
+
|
20 |
+
def download_command_factory(args):
|
21 |
+
return DownloadCommand(args.model, args.cache_dir, args.force, args.trust_remote_code)
|
22 |
+
|
23 |
+
|
24 |
+
class DownloadCommand(BaseTransformersCLICommand):
|
25 |
+
@staticmethod
|
26 |
+
def register_subcommand(parser: ArgumentParser):
|
27 |
+
download_parser = parser.add_parser("download")
|
28 |
+
download_parser.add_argument(
|
29 |
+
"--cache-dir", type=str, default=None, help="Path to location to store the models"
|
30 |
+
)
|
31 |
+
download_parser.add_argument(
|
32 |
+
"--force", action="store_true", help="Force the model to be download even if already in cache-dir"
|
33 |
+
)
|
34 |
+
download_parser.add_argument(
|
35 |
+
"--trust-remote-code",
|
36 |
+
action="store_true",
|
37 |
+
help="Whether or not to allow for custom models defined on the Hub in their own modeling files. Use only if you've reviewed the code as it will execute on your local machine",
|
38 |
+
)
|
39 |
+
download_parser.add_argument("model", type=str, help="Name of the model to download")
|
40 |
+
download_parser.set_defaults(func=download_command_factory)
|
41 |
+
|
42 |
+
def __init__(self, model: str, cache: str, force: bool, trust_remote_code: bool):
|
43 |
+
self._model = model
|
44 |
+
self._cache = cache
|
45 |
+
self._force = force
|
46 |
+
self._trust_remote_code = trust_remote_code
|
47 |
+
|
48 |
+
def run(self):
|
49 |
+
from ..models.auto import AutoModel, AutoTokenizer
|
50 |
+
|
51 |
+
AutoModel.from_pretrained(
|
52 |
+
self._model, cache_dir=self._cache, force_download=self._force, trust_remote_code=self._trust_remote_code
|
53 |
+
)
|
54 |
+
AutoTokenizer.from_pretrained(
|
55 |
+
self._model, cache_dir=self._cache, force_download=self._force, trust_remote_code=self._trust_remote_code
|
56 |
+
)
|
transformers_4_35_0/commands/env.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import importlib.util
|
16 |
+
import os
|
17 |
+
import platform
|
18 |
+
from argparse import ArgumentParser
|
19 |
+
|
20 |
+
import huggingface_hub
|
21 |
+
|
22 |
+
from .. import __version__ as version
|
23 |
+
from ..utils import (
|
24 |
+
is_accelerate_available,
|
25 |
+
is_flax_available,
|
26 |
+
is_safetensors_available,
|
27 |
+
is_tf_available,
|
28 |
+
is_torch_available,
|
29 |
+
)
|
30 |
+
from . import BaseTransformersCLICommand
|
31 |
+
|
32 |
+
|
33 |
+
def info_command_factory(_):
|
34 |
+
return EnvironmentCommand()
|
35 |
+
|
36 |
+
|
37 |
+
def download_command_factory(args):
|
38 |
+
return EnvironmentCommand(args.accelerate_config_file)
|
39 |
+
|
40 |
+
|
41 |
+
class EnvironmentCommand(BaseTransformersCLICommand):
|
42 |
+
@staticmethod
|
43 |
+
def register_subcommand(parser: ArgumentParser):
|
44 |
+
download_parser = parser.add_parser("env")
|
45 |
+
download_parser.set_defaults(func=info_command_factory)
|
46 |
+
download_parser.add_argument(
|
47 |
+
"--accelerate-config_file",
|
48 |
+
default=None,
|
49 |
+
help="The accelerate config file to use for the default values in the launching script.",
|
50 |
+
)
|
51 |
+
download_parser.set_defaults(func=download_command_factory)
|
52 |
+
|
53 |
+
def __init__(self, accelerate_config_file, *args) -> None:
|
54 |
+
self._accelerate_config_file = accelerate_config_file
|
55 |
+
|
56 |
+
def run(self):
|
57 |
+
safetensors_version = "not installed"
|
58 |
+
if is_safetensors_available():
|
59 |
+
import safetensors
|
60 |
+
|
61 |
+
safetensors_version = safetensors.__version__
|
62 |
+
elif importlib.util.find_spec("safetensors") is not None:
|
63 |
+
import safetensors
|
64 |
+
|
65 |
+
safetensors_version = f"{safetensors.__version__} but is ignored because of PyTorch version too old."
|
66 |
+
|
67 |
+
accelerate_version = "not installed"
|
68 |
+
accelerate_config = accelerate_config_str = "not found"
|
69 |
+
if is_accelerate_available():
|
70 |
+
import accelerate
|
71 |
+
from accelerate.commands.config import default_config_file, load_config_from_file
|
72 |
+
|
73 |
+
accelerate_version = accelerate.__version__
|
74 |
+
# Get the default from the config file.
|
75 |
+
if self._accelerate_config_file is not None or os.path.isfile(default_config_file):
|
76 |
+
accelerate_config = load_config_from_file(self._accelerate_config_file).to_dict()
|
77 |
+
|
78 |
+
accelerate_config_str = (
|
79 |
+
"\n".join([f"\t- {prop}: {val}" for prop, val in accelerate_config.items()])
|
80 |
+
if isinstance(accelerate_config, dict)
|
81 |
+
else f"\t{accelerate_config}"
|
82 |
+
)
|
83 |
+
|
84 |
+
pt_version = "not installed"
|
85 |
+
pt_cuda_available = "NA"
|
86 |
+
if is_torch_available():
|
87 |
+
import torch
|
88 |
+
|
89 |
+
pt_version = torch.__version__
|
90 |
+
pt_cuda_available = torch.cuda.is_available()
|
91 |
+
|
92 |
+
tf_version = "not installed"
|
93 |
+
tf_cuda_available = "NA"
|
94 |
+
if is_tf_available():
|
95 |
+
import tensorflow as tf
|
96 |
+
|
97 |
+
tf_version = tf.__version__
|
98 |
+
try:
|
99 |
+
# deprecated in v2.1
|
100 |
+
tf_cuda_available = tf.test.is_gpu_available()
|
101 |
+
except AttributeError:
|
102 |
+
# returns list of devices, convert to bool
|
103 |
+
tf_cuda_available = bool(tf.config.list_physical_devices("GPU"))
|
104 |
+
|
105 |
+
flax_version = "not installed"
|
106 |
+
jax_version = "not installed"
|
107 |
+
jaxlib_version = "not installed"
|
108 |
+
jax_backend = "NA"
|
109 |
+
if is_flax_available():
|
110 |
+
import flax
|
111 |
+
import jax
|
112 |
+
import jaxlib
|
113 |
+
|
114 |
+
flax_version = flax.__version__
|
115 |
+
jax_version = jax.__version__
|
116 |
+
jaxlib_version = jaxlib.__version__
|
117 |
+
jax_backend = jax.lib.xla_bridge.get_backend().platform
|
118 |
+
|
119 |
+
info = {
|
120 |
+
"`transformers` version": version,
|
121 |
+
"Platform": platform.platform(),
|
122 |
+
"Python version": platform.python_version(),
|
123 |
+
"Huggingface_hub version": huggingface_hub.__version__,
|
124 |
+
"Safetensors version": f"{safetensors_version}",
|
125 |
+
"Accelerate version": f"{accelerate_version}",
|
126 |
+
"Accelerate config": f"{accelerate_config_str}",
|
127 |
+
"PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
|
128 |
+
"Tensorflow version (GPU?)": f"{tf_version} ({tf_cuda_available})",
|
129 |
+
"Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
|
130 |
+
"Jax version": f"{jax_version}",
|
131 |
+
"JaxLib version": f"{jaxlib_version}",
|
132 |
+
"Using GPU in script?": "<fill in>",
|
133 |
+
"Using distributed or parallel set-up in script?": "<fill in>",
|
134 |
+
}
|
135 |
+
|
136 |
+
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
|
137 |
+
print(self.format_dict(info))
|
138 |
+
|
139 |
+
return info
|
140 |
+
|
141 |
+
@staticmethod
|
142 |
+
def format_dict(d):
|
143 |
+
return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
|
transformers_4_35_0/commands/lfs.py
ADDED
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs.
|
3 |
+
|
4 |
+
Inspired by: github.com/cbartz/git-lfs-swift-transfer-agent/blob/master/git_lfs_swift_transfer.py
|
5 |
+
|
6 |
+
Spec is: github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md
|
7 |
+
|
8 |
+
|
9 |
+
To launch debugger while developing:
|
10 |
+
|
11 |
+
``` [lfs "customtransfer.multipart"]
|
12 |
+
path = /path/to/transformers/.env/bin/python args = -m debugpy --listen 5678 --wait-for-client
|
13 |
+
/path/to/transformers/src/transformers/commands/transformers_cli.py lfs-multipart-upload ```"""
|
14 |
+
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
import subprocess
|
18 |
+
import sys
|
19 |
+
import warnings
|
20 |
+
from argparse import ArgumentParser
|
21 |
+
from contextlib import AbstractContextManager
|
22 |
+
from typing import Dict, List, Optional
|
23 |
+
|
24 |
+
import requests
|
25 |
+
|
26 |
+
from ..utils import logging
|
27 |
+
from . import BaseTransformersCLICommand
|
28 |
+
|
29 |
+
|
30 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
31 |
+
|
32 |
+
|
33 |
+
LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload"
|
34 |
+
|
35 |
+
|
36 |
+
class LfsCommands(BaseTransformersCLICommand):
|
37 |
+
"""
|
38 |
+
Implementation of a custom transfer agent for the transfer type "multipart" for git-lfs. This lets users upload
|
39 |
+
large files >5GB 🔥. Spec for LFS custom transfer agent is:
|
40 |
+
https://github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md
|
41 |
+
|
42 |
+
This introduces two commands to the CLI:
|
43 |
+
|
44 |
+
1. $ transformers-cli lfs-enable-largefiles
|
45 |
+
|
46 |
+
This should be executed once for each model repo that contains a model file >5GB. It's documented in the error
|
47 |
+
message you get if you just try to git push a 5GB file without having enabled it before.
|
48 |
+
|
49 |
+
2. $ transformers-cli lfs-multipart-upload
|
50 |
+
|
51 |
+
This command is called by lfs directly and is not meant to be called by the user.
|
52 |
+
"""
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def register_subcommand(parser: ArgumentParser):
|
56 |
+
enable_parser = parser.add_parser(
|
57 |
+
"lfs-enable-largefiles",
|
58 |
+
help=(
|
59 |
+
"Deprecated: use `huggingface-cli` instead. Configure your repository to enable upload of files > 5GB."
|
60 |
+
),
|
61 |
+
)
|
62 |
+
enable_parser.add_argument("path", type=str, help="Local path to repository you want to configure.")
|
63 |
+
enable_parser.set_defaults(func=lambda args: LfsEnableCommand(args))
|
64 |
+
|
65 |
+
upload_parser = parser.add_parser(
|
66 |
+
LFS_MULTIPART_UPLOAD_COMMAND,
|
67 |
+
help=(
|
68 |
+
"Deprecated: use `huggingface-cli` instead. "
|
69 |
+
"Command will get called by git-lfs, do not call it directly."
|
70 |
+
),
|
71 |
+
)
|
72 |
+
upload_parser.set_defaults(func=lambda args: LfsUploadCommand(args))
|
73 |
+
|
74 |
+
|
75 |
+
class LfsEnableCommand:
|
76 |
+
def __init__(self, args):
|
77 |
+
self.args = args
|
78 |
+
|
79 |
+
def run(self):
|
80 |
+
warnings.warn(
|
81 |
+
"Managing repositories through transformers-cli is deprecated. Please use `huggingface-cli` instead."
|
82 |
+
)
|
83 |
+
local_path = os.path.abspath(self.args.path)
|
84 |
+
if not os.path.isdir(local_path):
|
85 |
+
print("This does not look like a valid git repo.")
|
86 |
+
exit(1)
|
87 |
+
subprocess.run(
|
88 |
+
"git config lfs.customtransfer.multipart.path transformers-cli".split(), check=True, cwd=local_path
|
89 |
+
)
|
90 |
+
subprocess.run(
|
91 |
+
f"git config lfs.customtransfer.multipart.args {LFS_MULTIPART_UPLOAD_COMMAND}".split(),
|
92 |
+
check=True,
|
93 |
+
cwd=local_path,
|
94 |
+
)
|
95 |
+
print("Local repo set up for largefiles")
|
96 |
+
|
97 |
+
|
98 |
+
def write_msg(msg: Dict):
|
99 |
+
"""Write out the message in Line delimited JSON."""
|
100 |
+
msg = json.dumps(msg) + "\n"
|
101 |
+
sys.stdout.write(msg)
|
102 |
+
sys.stdout.flush()
|
103 |
+
|
104 |
+
|
105 |
+
def read_msg() -> Optional[Dict]:
|
106 |
+
"""Read Line delimited JSON from stdin."""
|
107 |
+
msg = json.loads(sys.stdin.readline().strip())
|
108 |
+
|
109 |
+
if "terminate" in (msg.get("type"), msg.get("event")):
|
110 |
+
# terminate message received
|
111 |
+
return None
|
112 |
+
|
113 |
+
if msg.get("event") not in ("download", "upload"):
|
114 |
+
logger.critical("Received unexpected message")
|
115 |
+
sys.exit(1)
|
116 |
+
|
117 |
+
return msg
|
118 |
+
|
119 |
+
|
120 |
+
class FileSlice(AbstractContextManager):
|
121 |
+
"""
|
122 |
+
File-like object that only reads a slice of a file
|
123 |
+
|
124 |
+
Inspired by stackoverflow.com/a/29838711/593036
|
125 |
+
"""
|
126 |
+
|
127 |
+
def __init__(self, filepath: str, seek_from: int, read_limit: int):
|
128 |
+
self.filepath = filepath
|
129 |
+
self.seek_from = seek_from
|
130 |
+
self.read_limit = read_limit
|
131 |
+
self.n_seen = 0
|
132 |
+
|
133 |
+
def __enter__(self):
|
134 |
+
self.f = open(self.filepath, "rb")
|
135 |
+
self.f.seek(self.seek_from)
|
136 |
+
return self
|
137 |
+
|
138 |
+
def __len__(self):
|
139 |
+
total_length = os.fstat(self.f.fileno()).st_size
|
140 |
+
return min(self.read_limit, total_length - self.seek_from)
|
141 |
+
|
142 |
+
def read(self, n=-1):
|
143 |
+
if self.n_seen >= self.read_limit:
|
144 |
+
return b""
|
145 |
+
remaining_amount = self.read_limit - self.n_seen
|
146 |
+
data = self.f.read(remaining_amount if n < 0 else min(n, remaining_amount))
|
147 |
+
self.n_seen += len(data)
|
148 |
+
return data
|
149 |
+
|
150 |
+
def __iter__(self):
|
151 |
+
yield self.read(n=4 * 1024 * 1024)
|
152 |
+
|
153 |
+
def __exit__(self, *args):
|
154 |
+
self.f.close()
|
155 |
+
|
156 |
+
|
157 |
+
class LfsUploadCommand:
|
158 |
+
def __init__(self, args):
|
159 |
+
self.args = args
|
160 |
+
|
161 |
+
def run(self):
|
162 |
+
# Immediately after invoking a custom transfer process, git-lfs
|
163 |
+
# sends initiation data to the process over stdin.
|
164 |
+
# This tells the process useful information about the configuration.
|
165 |
+
init_msg = json.loads(sys.stdin.readline().strip())
|
166 |
+
if not (init_msg.get("event") == "init" and init_msg.get("operation") == "upload"):
|
167 |
+
write_msg({"error": {"code": 32, "message": "Wrong lfs init operation"}})
|
168 |
+
sys.exit(1)
|
169 |
+
|
170 |
+
# The transfer process should use the information it needs from the
|
171 |
+
# initiation structure, and also perform any one-off setup tasks it
|
172 |
+
# needs to do. It should then respond on stdout with a simple empty
|
173 |
+
# confirmation structure, as follows:
|
174 |
+
write_msg({})
|
175 |
+
|
176 |
+
# After the initiation exchange, git-lfs will send any number of
|
177 |
+
# transfer requests to the stdin of the transfer process, in a serial sequence.
|
178 |
+
while True:
|
179 |
+
msg = read_msg()
|
180 |
+
if msg is None:
|
181 |
+
# When all transfers have been processed, git-lfs will send
|
182 |
+
# a terminate event to the stdin of the transfer process.
|
183 |
+
# On receiving this message the transfer process should
|
184 |
+
# clean up and terminate. No response is expected.
|
185 |
+
sys.exit(0)
|
186 |
+
|
187 |
+
oid = msg["oid"]
|
188 |
+
filepath = msg["path"]
|
189 |
+
completion_url = msg["action"]["href"]
|
190 |
+
header = msg["action"]["header"]
|
191 |
+
chunk_size = int(header.pop("chunk_size"))
|
192 |
+
presigned_urls: List[str] = list(header.values())
|
193 |
+
|
194 |
+
parts = []
|
195 |
+
for i, presigned_url in enumerate(presigned_urls):
|
196 |
+
with FileSlice(filepath, seek_from=i * chunk_size, read_limit=chunk_size) as data:
|
197 |
+
r = requests.put(presigned_url, data=data)
|
198 |
+
r.raise_for_status()
|
199 |
+
parts.append(
|
200 |
+
{
|
201 |
+
"etag": r.headers.get("etag"),
|
202 |
+
"partNumber": i + 1,
|
203 |
+
}
|
204 |
+
)
|
205 |
+
# In order to support progress reporting while data is uploading / downloading,
|
206 |
+
# the transfer process should post messages to stdout
|
207 |
+
write_msg(
|
208 |
+
{
|
209 |
+
"event": "progress",
|
210 |
+
"oid": oid,
|
211 |
+
"bytesSoFar": (i + 1) * chunk_size,
|
212 |
+
"bytesSinceLast": chunk_size,
|
213 |
+
}
|
214 |
+
)
|
215 |
+
# Not precise but that's ok.
|
216 |
+
|
217 |
+
r = requests.post(
|
218 |
+
completion_url,
|
219 |
+
json={
|
220 |
+
"oid": oid,
|
221 |
+
"parts": parts,
|
222 |
+
},
|
223 |
+
)
|
224 |
+
r.raise_for_status()
|
225 |
+
|
226 |
+
write_msg({"event": "complete", "oid": oid})
|
transformers_4_35_0/commands/pt_to_tf.py
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
import os
|
17 |
+
from argparse import ArgumentParser, Namespace
|
18 |
+
from importlib import import_module
|
19 |
+
|
20 |
+
import huggingface_hub
|
21 |
+
import numpy as np
|
22 |
+
from packaging import version
|
23 |
+
|
24 |
+
from .. import (
|
25 |
+
FEATURE_EXTRACTOR_MAPPING,
|
26 |
+
IMAGE_PROCESSOR_MAPPING,
|
27 |
+
PROCESSOR_MAPPING,
|
28 |
+
TOKENIZER_MAPPING,
|
29 |
+
AutoConfig,
|
30 |
+
AutoFeatureExtractor,
|
31 |
+
AutoImageProcessor,
|
32 |
+
AutoProcessor,
|
33 |
+
AutoTokenizer,
|
34 |
+
is_datasets_available,
|
35 |
+
is_tf_available,
|
36 |
+
is_torch_available,
|
37 |
+
)
|
38 |
+
from ..utils import TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
|
39 |
+
from . import BaseTransformersCLICommand
|
40 |
+
|
41 |
+
|
42 |
+
if is_tf_available():
|
43 |
+
import tensorflow as tf
|
44 |
+
|
45 |
+
tf.config.experimental.enable_tensor_float_32_execution(False)
|
46 |
+
|
47 |
+
if is_torch_available():
|
48 |
+
import torch
|
49 |
+
|
50 |
+
if is_datasets_available():
|
51 |
+
from datasets import load_dataset
|
52 |
+
|
53 |
+
|
54 |
+
MAX_ERROR = 5e-5 # larger error tolerance than in our internal tests, to avoid flaky user-facing errors
|
55 |
+
|
56 |
+
|
57 |
+
def convert_command_factory(args: Namespace):
|
58 |
+
"""
|
59 |
+
Factory function used to convert a model PyTorch checkpoint in a TensorFlow 2 checkpoint.
|
60 |
+
|
61 |
+
Returns: ServeCommand
|
62 |
+
"""
|
63 |
+
return PTtoTFCommand(
|
64 |
+
args.model_name,
|
65 |
+
args.local_dir,
|
66 |
+
args.max_error,
|
67 |
+
args.new_weights,
|
68 |
+
args.no_pr,
|
69 |
+
args.push,
|
70 |
+
args.extra_commit_description,
|
71 |
+
args.override_model_class,
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
class PTtoTFCommand(BaseTransformersCLICommand):
|
76 |
+
@staticmethod
|
77 |
+
def register_subcommand(parser: ArgumentParser):
|
78 |
+
"""
|
79 |
+
Register this command to argparse so it's available for the transformer-cli
|
80 |
+
|
81 |
+
Args:
|
82 |
+
parser: Root parser to register command-specific arguments
|
83 |
+
"""
|
84 |
+
train_parser = parser.add_parser(
|
85 |
+
"pt-to-tf",
|
86 |
+
help=(
|
87 |
+
"CLI tool to run convert a transformers model from a PyTorch checkpoint to a TensorFlow checkpoint."
|
88 |
+
" Can also be used to validate existing weights without opening PRs, with --no-pr."
|
89 |
+
),
|
90 |
+
)
|
91 |
+
train_parser.add_argument(
|
92 |
+
"--model-name",
|
93 |
+
type=str,
|
94 |
+
required=True,
|
95 |
+
help="The model name, including owner/organization, as seen on the hub.",
|
96 |
+
)
|
97 |
+
train_parser.add_argument(
|
98 |
+
"--local-dir",
|
99 |
+
type=str,
|
100 |
+
default="",
|
101 |
+
help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
|
102 |
+
)
|
103 |
+
train_parser.add_argument(
|
104 |
+
"--max-error",
|
105 |
+
type=float,
|
106 |
+
default=MAX_ERROR,
|
107 |
+
help=(
|
108 |
+
f"Maximum error tolerance. Defaults to {MAX_ERROR}. This flag should be avoided, use at your own risk."
|
109 |
+
),
|
110 |
+
)
|
111 |
+
train_parser.add_argument(
|
112 |
+
"--new-weights",
|
113 |
+
action="store_true",
|
114 |
+
help="Optional flag to create new TensorFlow weights, even if they already exist.",
|
115 |
+
)
|
116 |
+
train_parser.add_argument(
|
117 |
+
"--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
|
118 |
+
)
|
119 |
+
train_parser.add_argument(
|
120 |
+
"--push",
|
121 |
+
action="store_true",
|
122 |
+
help="Optional flag to push the weights directly to `main` (requires permissions)",
|
123 |
+
)
|
124 |
+
train_parser.add_argument(
|
125 |
+
"--extra-commit-description",
|
126 |
+
type=str,
|
127 |
+
default="",
|
128 |
+
help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
|
129 |
+
)
|
130 |
+
train_parser.add_argument(
|
131 |
+
"--override-model-class",
|
132 |
+
type=str,
|
133 |
+
default=None,
|
134 |
+
help="If you think you know better than the auto-detector, you can specify the model class here. "
|
135 |
+
"Can be either an AutoModel class or a specific model class like BertForSequenceClassification.",
|
136 |
+
)
|
137 |
+
train_parser.set_defaults(func=convert_command_factory)
|
138 |
+
|
139 |
+
@staticmethod
|
140 |
+
def find_pt_tf_differences(pt_outputs, tf_outputs):
|
141 |
+
"""
|
142 |
+
Compares the TensorFlow and PyTorch outputs, returning a dictionary with all tensor differences.
|
143 |
+
"""
|
144 |
+
# 1. All output attributes must be the same
|
145 |
+
pt_out_attrs = set(pt_outputs.keys())
|
146 |
+
tf_out_attrs = set(tf_outputs.keys())
|
147 |
+
if pt_out_attrs != tf_out_attrs:
|
148 |
+
raise ValueError(
|
149 |
+
f"The model outputs have different attributes, aborting. (Pytorch: {pt_out_attrs}, TensorFlow:"
|
150 |
+
f" {tf_out_attrs})"
|
151 |
+
)
|
152 |
+
|
153 |
+
# 2. For each output attribute, computes the difference
|
154 |
+
def _find_pt_tf_differences(pt_out, tf_out, differences, attr_name=""):
|
155 |
+
# If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
|
156 |
+
# recursivelly, keeping the name of the attribute.
|
157 |
+
if isinstance(pt_out, torch.Tensor):
|
158 |
+
tensor_difference = np.max(np.abs(pt_out.numpy() - tf_out.numpy()))
|
159 |
+
differences[attr_name] = tensor_difference
|
160 |
+
else:
|
161 |
+
root_name = attr_name
|
162 |
+
for i, pt_item in enumerate(pt_out):
|
163 |
+
# If it is a named attribute, we keep the name. Otherwise, just its index.
|
164 |
+
if isinstance(pt_item, str):
|
165 |
+
branch_name = root_name + pt_item
|
166 |
+
tf_item = tf_out[pt_item]
|
167 |
+
pt_item = pt_out[pt_item]
|
168 |
+
else:
|
169 |
+
branch_name = root_name + f"[{i}]"
|
170 |
+
tf_item = tf_out[i]
|
171 |
+
differences = _find_pt_tf_differences(pt_item, tf_item, differences, branch_name)
|
172 |
+
|
173 |
+
return differences
|
174 |
+
|
175 |
+
return _find_pt_tf_differences(pt_outputs, tf_outputs, {})
|
176 |
+
|
177 |
+
def __init__(
|
178 |
+
self,
|
179 |
+
model_name: str,
|
180 |
+
local_dir: str,
|
181 |
+
max_error: float,
|
182 |
+
new_weights: bool,
|
183 |
+
no_pr: bool,
|
184 |
+
push: bool,
|
185 |
+
extra_commit_description: str,
|
186 |
+
override_model_class: str,
|
187 |
+
*args,
|
188 |
+
):
|
189 |
+
self._logger = logging.get_logger("transformers-cli/pt_to_tf")
|
190 |
+
self._model_name = model_name
|
191 |
+
self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
|
192 |
+
self._max_error = max_error
|
193 |
+
self._new_weights = new_weights
|
194 |
+
self._no_pr = no_pr
|
195 |
+
self._push = push
|
196 |
+
self._extra_commit_description = extra_commit_description
|
197 |
+
self._override_model_class = override_model_class
|
198 |
+
|
199 |
+
def get_inputs(self, pt_model, tf_dummy_inputs, config):
|
200 |
+
"""
|
201 |
+
Returns the right inputs for the model, based on its signature.
|
202 |
+
"""
|
203 |
+
|
204 |
+
def _get_audio_input():
|
205 |
+
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
206 |
+
speech_samples = ds.sort("id").select(range(2))[:2]["audio"]
|
207 |
+
raw_samples = [x["array"] for x in speech_samples]
|
208 |
+
return raw_samples
|
209 |
+
|
210 |
+
model_config_class = type(pt_model.config)
|
211 |
+
if model_config_class in PROCESSOR_MAPPING:
|
212 |
+
processor = AutoProcessor.from_pretrained(self._local_dir)
|
213 |
+
if model_config_class in TOKENIZER_MAPPING and processor.tokenizer.pad_token is None:
|
214 |
+
processor.tokenizer.pad_token = processor.tokenizer.eos_token
|
215 |
+
elif model_config_class in IMAGE_PROCESSOR_MAPPING:
|
216 |
+
processor = AutoImageProcessor.from_pretrained(self._local_dir)
|
217 |
+
elif model_config_class in FEATURE_EXTRACTOR_MAPPING:
|
218 |
+
processor = AutoFeatureExtractor.from_pretrained(self._local_dir)
|
219 |
+
elif model_config_class in TOKENIZER_MAPPING:
|
220 |
+
processor = AutoTokenizer.from_pretrained(self._local_dir)
|
221 |
+
if processor.pad_token is None:
|
222 |
+
processor.pad_token = processor.eos_token
|
223 |
+
else:
|
224 |
+
raise ValueError(f"Unknown data processing type (model config type: {model_config_class})")
|
225 |
+
|
226 |
+
model_forward_signature = set(inspect.signature(pt_model.forward).parameters.keys())
|
227 |
+
processor_inputs = {}
|
228 |
+
if "input_ids" in model_forward_signature:
|
229 |
+
processor_inputs.update(
|
230 |
+
{
|
231 |
+
"text": ["Hi there!", "I am a batch with more than one row and different input lengths."],
|
232 |
+
"padding": True,
|
233 |
+
"truncation": True,
|
234 |
+
}
|
235 |
+
)
|
236 |
+
if "pixel_values" in model_forward_signature:
|
237 |
+
sample_images = load_dataset("cifar10", "plain_text", split="test")[:2]["img"]
|
238 |
+
processor_inputs.update({"images": sample_images})
|
239 |
+
if "input_features" in model_forward_signature:
|
240 |
+
feature_extractor_signature = inspect.signature(processor.feature_extractor).parameters
|
241 |
+
# Pad to the largest input length by default but take feature extractor default
|
242 |
+
# padding value if it exists e.g. "max_length" and is not False or None
|
243 |
+
if "padding" in feature_extractor_signature:
|
244 |
+
default_strategy = feature_extractor_signature["padding"].default
|
245 |
+
if default_strategy is not False and default_strategy is not None:
|
246 |
+
padding_strategy = default_strategy
|
247 |
+
else:
|
248 |
+
padding_strategy = True
|
249 |
+
else:
|
250 |
+
padding_strategy = True
|
251 |
+
processor_inputs.update({"audio": _get_audio_input(), "padding": padding_strategy})
|
252 |
+
if "input_values" in model_forward_signature: # Wav2Vec2 audio input
|
253 |
+
processor_inputs.update({"audio": _get_audio_input(), "padding": True})
|
254 |
+
pt_input = processor(**processor_inputs, return_tensors="pt")
|
255 |
+
tf_input = processor(**processor_inputs, return_tensors="tf")
|
256 |
+
|
257 |
+
# Extra input requirements, in addition to the input modality
|
258 |
+
if (
|
259 |
+
config.is_encoder_decoder
|
260 |
+
or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"))
|
261 |
+
or "decoder_input_ids" in tf_dummy_inputs
|
262 |
+
):
|
263 |
+
decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0)
|
264 |
+
pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)})
|
265 |
+
tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
|
266 |
+
|
267 |
+
return pt_input, tf_input
|
268 |
+
|
269 |
+
def run(self):
|
270 |
+
# hub version 0.9.0 introduced the possibility of programmatically opening PRs with normal write tokens.
|
271 |
+
if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
|
272 |
+
raise ImportError(
|
273 |
+
"The huggingface_hub version must be >= 0.9.0 to use this command. Please update your huggingface_hub"
|
274 |
+
" installation."
|
275 |
+
)
|
276 |
+
else:
|
277 |
+
from huggingface_hub import Repository, create_commit
|
278 |
+
from huggingface_hub._commit_api import CommitOperationAdd
|
279 |
+
|
280 |
+
# Fetch remote data
|
281 |
+
repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
|
282 |
+
|
283 |
+
# Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
|
284 |
+
config = AutoConfig.from_pretrained(self._local_dir)
|
285 |
+
architectures = config.architectures
|
286 |
+
if self._override_model_class is not None:
|
287 |
+
if self._override_model_class.startswith("TF"):
|
288 |
+
architectures = [self._override_model_class[2:]]
|
289 |
+
else:
|
290 |
+
architectures = [self._override_model_class]
|
291 |
+
try:
|
292 |
+
pt_class = getattr(import_module("transformers"), architectures[0])
|
293 |
+
except AttributeError:
|
294 |
+
raise ValueError(f"Model class {self._override_model_class} not found in transformers.")
|
295 |
+
try:
|
296 |
+
tf_class = getattr(import_module("transformers"), "TF" + architectures[0])
|
297 |
+
except AttributeError:
|
298 |
+
raise ValueError(f"TF model class TF{self._override_model_class} not found in transformers.")
|
299 |
+
elif architectures is None: # No architecture defined -- use auto classes
|
300 |
+
pt_class = getattr(import_module("transformers"), "AutoModel")
|
301 |
+
tf_class = getattr(import_module("transformers"), "TFAutoModel")
|
302 |
+
self._logger.warning("No detected architecture, using AutoModel/TFAutoModel")
|
303 |
+
else: # Architecture defined -- use it
|
304 |
+
if len(architectures) > 1:
|
305 |
+
raise ValueError(f"More than one architecture was found, aborting. (architectures = {architectures})")
|
306 |
+
self._logger.warning(f"Detected architecture: {architectures[0]}")
|
307 |
+
pt_class = getattr(import_module("transformers"), architectures[0])
|
308 |
+
try:
|
309 |
+
tf_class = getattr(import_module("transformers"), "TF" + architectures[0])
|
310 |
+
except AttributeError:
|
311 |
+
raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")
|
312 |
+
|
313 |
+
# Check the TF dummy inputs to see what keys we need in the forward pass
|
314 |
+
tf_from_pt_model = tf_class.from_config(config)
|
315 |
+
tf_dummy_inputs = tf_from_pt_model.dummy_inputs
|
316 |
+
|
317 |
+
del tf_from_pt_model # Try to keep only one model in memory at a time
|
318 |
+
|
319 |
+
# Load the model and get some basic inputs
|
320 |
+
pt_model = pt_class.from_pretrained(self._local_dir)
|
321 |
+
pt_model.eval()
|
322 |
+
|
323 |
+
pt_input, tf_input = self.get_inputs(pt_model, tf_dummy_inputs, config)
|
324 |
+
|
325 |
+
with torch.no_grad():
|
326 |
+
pt_outputs = pt_model(**pt_input, output_hidden_states=True)
|
327 |
+
del pt_model # will no longer be used, and may have a large memory footprint
|
328 |
+
|
329 |
+
tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
|
330 |
+
tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True, training=False)
|
331 |
+
|
332 |
+
# Confirms that cross loading PT weights into TF worked.
|
333 |
+
crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs)
|
334 |
+
output_differences = {k: v for k, v in crossload_differences.items() if "hidden" not in k}
|
335 |
+
hidden_differences = {k: v for k, v in crossload_differences.items() if "hidden" in k}
|
336 |
+
if len(output_differences) == 0 and architectures is not None:
|
337 |
+
raise ValueError(
|
338 |
+
f"Something went wrong -- the config file has architectures ({architectures}), but no model head"
|
339 |
+
" output was found. All outputs start with 'hidden'"
|
340 |
+
)
|
341 |
+
max_crossload_output_diff = max(output_differences.values()) if output_differences else 0.0
|
342 |
+
max_crossload_hidden_diff = max(hidden_differences.values())
|
343 |
+
if max_crossload_output_diff > self._max_error or max_crossload_hidden_diff > self._max_error:
|
344 |
+
raise ValueError(
|
345 |
+
"The cross-loaded TensorFlow model has different outputs, something went wrong!\n"
|
346 |
+
+ f"\nList of maximum output differences above the threshold ({self._max_error}):\n"
|
347 |
+
+ "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > self._max_error])
|
348 |
+
+ f"\n\nList of maximum hidden layer differences above the threshold ({self._max_error}):\n"
|
349 |
+
+ "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_error])
|
350 |
+
)
|
351 |
+
|
352 |
+
# Save the weights in a TF format (if needed) and confirms that the results are still good
|
353 |
+
tf_weights_path = os.path.join(self._local_dir, TF2_WEIGHTS_NAME)
|
354 |
+
tf_weights_index_path = os.path.join(self._local_dir, TF2_WEIGHTS_INDEX_NAME)
|
355 |
+
if (not os.path.exists(tf_weights_path) and not os.path.exists(tf_weights_index_path)) or self._new_weights:
|
356 |
+
tf_from_pt_model.save_pretrained(self._local_dir)
|
357 |
+
del tf_from_pt_model # will no longer be used, and may have a large memory footprint
|
358 |
+
|
359 |
+
tf_model = tf_class.from_pretrained(self._local_dir)
|
360 |
+
tf_outputs = tf_model(**tf_input, output_hidden_states=True)
|
361 |
+
|
362 |
+
conversion_differences = self.find_pt_tf_differences(pt_outputs, tf_outputs)
|
363 |
+
output_differences = {k: v for k, v in conversion_differences.items() if "hidden" not in k}
|
364 |
+
hidden_differences = {k: v for k, v in conversion_differences.items() if "hidden" in k}
|
365 |
+
if len(output_differences) == 0 and architectures is not None:
|
366 |
+
raise ValueError(
|
367 |
+
f"Something went wrong -- the config file has architectures ({architectures}), but no model head"
|
368 |
+
" output was found. All outputs start with 'hidden'"
|
369 |
+
)
|
370 |
+
max_conversion_output_diff = max(output_differences.values()) if output_differences else 0.0
|
371 |
+
max_conversion_hidden_diff = max(hidden_differences.values())
|
372 |
+
if max_conversion_output_diff > self._max_error or max_conversion_hidden_diff > self._max_error:
|
373 |
+
raise ValueError(
|
374 |
+
"The converted TensorFlow model has different outputs, something went wrong!\n"
|
375 |
+
+ f"\nList of maximum output differences above the threshold ({self._max_error}):\n"
|
376 |
+
+ "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > self._max_error])
|
377 |
+
+ f"\n\nList of maximum hidden layer differences above the threshold ({self._max_error}):\n"
|
378 |
+
+ "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_error])
|
379 |
+
)
|
380 |
+
|
381 |
+
commit_message = "Update TF weights" if self._new_weights else "Add TF weights"
|
382 |
+
if self._push:
|
383 |
+
repo.git_add(auto_lfs_track=True)
|
384 |
+
repo.git_commit(commit_message)
|
385 |
+
repo.git_push(blocking=True) # this prints a progress bar with the upload
|
386 |
+
self._logger.warning(f"TF weights pushed into {self._model_name}")
|
387 |
+
elif not self._no_pr:
|
388 |
+
self._logger.warning("Uploading the weights into a new PR...")
|
389 |
+
commit_descrition = (
|
390 |
+
"Model converted by the [`transformers`' `pt_to_tf`"
|
391 |
+
" CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py). "
|
392 |
+
"All converted model outputs and hidden layers were validated against its PyTorch counterpart.\n\n"
|
393 |
+
f"Maximum crossload output difference={max_crossload_output_diff:.3e}; "
|
394 |
+
f"Maximum crossload hidden layer difference={max_crossload_hidden_diff:.3e};\n"
|
395 |
+
f"Maximum conversion output difference={max_conversion_output_diff:.3e}; "
|
396 |
+
f"Maximum conversion hidden layer difference={max_conversion_hidden_diff:.3e};\n"
|
397 |
+
)
|
398 |
+
if self._max_error > MAX_ERROR:
|
399 |
+
commit_descrition += (
|
400 |
+
f"\n\nCAUTION: The maximum admissible error was manually increased to {self._max_error}!"
|
401 |
+
)
|
402 |
+
if self._extra_commit_description:
|
403 |
+
commit_descrition += "\n\n" + self._extra_commit_description
|
404 |
+
|
405 |
+
# sharded model -> adds all related files (index and .h5 shards)
|
406 |
+
if os.path.exists(tf_weights_index_path):
|
407 |
+
operations = [
|
408 |
+
CommitOperationAdd(path_in_repo=TF2_WEIGHTS_INDEX_NAME, path_or_fileobj=tf_weights_index_path)
|
409 |
+
]
|
410 |
+
for shard_path in tf.io.gfile.glob(self._local_dir + "/tf_model-*.h5"):
|
411 |
+
operations += [
|
412 |
+
CommitOperationAdd(path_in_repo=os.path.basename(shard_path), path_or_fileobj=shard_path)
|
413 |
+
]
|
414 |
+
else:
|
415 |
+
operations = [CommitOperationAdd(path_in_repo=TF2_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)]
|
416 |
+
|
417 |
+
hub_pr_url = create_commit(
|
418 |
+
repo_id=self._model_name,
|
419 |
+
operations=operations,
|
420 |
+
commit_message=commit_message,
|
421 |
+
commit_description=commit_descrition,
|
422 |
+
repo_type="model",
|
423 |
+
create_pr=True,
|
424 |
+
).pr_url
|
425 |
+
self._logger.warning(f"PR open in {hub_pr_url}")
|
transformers_4_35_0/commands/run.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from argparse import ArgumentParser
|
16 |
+
|
17 |
+
from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
|
18 |
+
from ..utils import logging
|
19 |
+
from . import BaseTransformersCLICommand
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
23 |
+
|
24 |
+
|
25 |
+
def try_infer_format_from_ext(path: str):
|
26 |
+
if not path:
|
27 |
+
return "pipe"
|
28 |
+
|
29 |
+
for ext in PipelineDataFormat.SUPPORTED_FORMATS:
|
30 |
+
if path.endswith(ext):
|
31 |
+
return ext
|
32 |
+
|
33 |
+
raise Exception(
|
34 |
+
f"Unable to determine file format from file extension {path}. "
|
35 |
+
f"Please provide the format through --format {PipelineDataFormat.SUPPORTED_FORMATS}"
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
def run_command_factory(args):
|
40 |
+
nlp = pipeline(
|
41 |
+
task=args.task,
|
42 |
+
model=args.model if args.model else None,
|
43 |
+
config=args.config,
|
44 |
+
tokenizer=args.tokenizer,
|
45 |
+
device=args.device,
|
46 |
+
)
|
47 |
+
format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format
|
48 |
+
reader = PipelineDataFormat.from_str(
|
49 |
+
format=format,
|
50 |
+
output_path=args.output,
|
51 |
+
input_path=args.input,
|
52 |
+
column=args.column if args.column else nlp.default_input_names,
|
53 |
+
overwrite=args.overwrite,
|
54 |
+
)
|
55 |
+
return RunCommand(nlp, reader)
|
56 |
+
|
57 |
+
|
58 |
+
class RunCommand(BaseTransformersCLICommand):
|
59 |
+
def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
|
60 |
+
self._nlp = nlp
|
61 |
+
self._reader = reader
|
62 |
+
|
63 |
+
@staticmethod
|
64 |
+
def register_subcommand(parser: ArgumentParser):
|
65 |
+
run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
|
66 |
+
run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run")
|
67 |
+
run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
|
68 |
+
run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
|
69 |
+
run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
|
70 |
+
run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
|
71 |
+
run_parser.add_argument(
|
72 |
+
"--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
|
73 |
+
)
|
74 |
+
run_parser.add_argument(
|
75 |
+
"--column",
|
76 |
+
type=str,
|
77 |
+
help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
|
78 |
+
)
|
79 |
+
run_parser.add_argument(
|
80 |
+
"--format",
|
81 |
+
type=str,
|
82 |
+
default="infer",
|
83 |
+
choices=PipelineDataFormat.SUPPORTED_FORMATS,
|
84 |
+
help="Input format to read from",
|
85 |
+
)
|
86 |
+
run_parser.add_argument(
|
87 |
+
"--device",
|
88 |
+
type=int,
|
89 |
+
default=-1,
|
90 |
+
help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
|
91 |
+
)
|
92 |
+
run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
|
93 |
+
run_parser.set_defaults(func=run_command_factory)
|
94 |
+
|
95 |
+
def run(self):
|
96 |
+
nlp, outputs = self._nlp, []
|
97 |
+
|
98 |
+
for entry in self._reader:
|
99 |
+
output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
|
100 |
+
if isinstance(output, dict):
|
101 |
+
outputs.append(output)
|
102 |
+
else:
|
103 |
+
outputs += output
|
104 |
+
|
105 |
+
# Saving data
|
106 |
+
if self._nlp.binary_output:
|
107 |
+
binary_path = self._reader.save_binary(outputs)
|
108 |
+
logger.warning(f"Current pipeline requires output to be in binary format, saving at {binary_path}")
|
109 |
+
else:
|
110 |
+
self._reader.save(outputs)
|
transformers_4_35_0/commands/serving.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from argparse import ArgumentParser, Namespace
|
16 |
+
from typing import Any, List, Optional
|
17 |
+
|
18 |
+
from ..pipelines import Pipeline, get_supported_tasks, pipeline
|
19 |
+
from ..utils import logging
|
20 |
+
from . import BaseTransformersCLICommand
|
21 |
+
|
22 |
+
|
23 |
+
try:
|
24 |
+
from fastapi import Body, FastAPI, HTTPException
|
25 |
+
from fastapi.routing import APIRoute
|
26 |
+
from pydantic import BaseModel
|
27 |
+
from starlette.responses import JSONResponse
|
28 |
+
from uvicorn import run
|
29 |
+
|
30 |
+
_serve_dependencies_installed = True
|
31 |
+
except (ImportError, AttributeError):
|
32 |
+
BaseModel = object
|
33 |
+
|
34 |
+
def Body(*x, **y):
|
35 |
+
pass
|
36 |
+
|
37 |
+
_serve_dependencies_installed = False
|
38 |
+
|
39 |
+
|
40 |
+
logger = logging.get_logger("transformers-cli/serving")
|
41 |
+
|
42 |
+
|
43 |
+
def serve_command_factory(args: Namespace):
|
44 |
+
"""
|
45 |
+
Factory function used to instantiate serving server from provided command line arguments.
|
46 |
+
|
47 |
+
Returns: ServeCommand
|
48 |
+
"""
|
49 |
+
nlp = pipeline(
|
50 |
+
task=args.task,
|
51 |
+
model=args.model if args.model else None,
|
52 |
+
config=args.config,
|
53 |
+
tokenizer=args.tokenizer,
|
54 |
+
device=args.device,
|
55 |
+
)
|
56 |
+
return ServeCommand(nlp, args.host, args.port, args.workers)
|
57 |
+
|
58 |
+
|
59 |
+
class ServeModelInfoResult(BaseModel):
|
60 |
+
"""
|
61 |
+
Expose model information
|
62 |
+
"""
|
63 |
+
|
64 |
+
infos: dict
|
65 |
+
|
66 |
+
|
67 |
+
class ServeTokenizeResult(BaseModel):
|
68 |
+
"""
|
69 |
+
Tokenize result model
|
70 |
+
"""
|
71 |
+
|
72 |
+
tokens: List[str]
|
73 |
+
tokens_ids: Optional[List[int]]
|
74 |
+
|
75 |
+
|
76 |
+
class ServeDeTokenizeResult(BaseModel):
|
77 |
+
"""
|
78 |
+
DeTokenize result model
|
79 |
+
"""
|
80 |
+
|
81 |
+
text: str
|
82 |
+
|
83 |
+
|
84 |
+
class ServeForwardResult(BaseModel):
|
85 |
+
"""
|
86 |
+
Forward result model
|
87 |
+
"""
|
88 |
+
|
89 |
+
output: Any
|
90 |
+
|
91 |
+
|
92 |
+
class ServeCommand(BaseTransformersCLICommand):
|
93 |
+
@staticmethod
|
94 |
+
def register_subcommand(parser: ArgumentParser):
|
95 |
+
"""
|
96 |
+
Register this command to argparse so it's available for the transformer-cli
|
97 |
+
|
98 |
+
Args:
|
99 |
+
parser: Root parser to register command-specific arguments
|
100 |
+
"""
|
101 |
+
serve_parser = parser.add_parser(
|
102 |
+
"serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
|
103 |
+
)
|
104 |
+
serve_parser.add_argument(
|
105 |
+
"--task",
|
106 |
+
type=str,
|
107 |
+
choices=get_supported_tasks(),
|
108 |
+
help="The task to run the pipeline on",
|
109 |
+
)
|
110 |
+
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
|
111 |
+
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
|
112 |
+
serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
|
113 |
+
serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
|
114 |
+
serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
|
115 |
+
serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
|
116 |
+
serve_parser.add_argument(
|
117 |
+
"--device",
|
118 |
+
type=int,
|
119 |
+
default=-1,
|
120 |
+
help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
|
121 |
+
)
|
122 |
+
serve_parser.set_defaults(func=serve_command_factory)
|
123 |
+
|
124 |
+
def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
|
125 |
+
self._pipeline = pipeline
|
126 |
+
|
127 |
+
self.host = host
|
128 |
+
self.port = port
|
129 |
+
self.workers = workers
|
130 |
+
|
131 |
+
if not _serve_dependencies_installed:
|
132 |
+
raise RuntimeError(
|
133 |
+
"Using serve command requires FastAPI and uvicorn. "
|
134 |
+
'Please install transformers with [serving]: pip install "transformers[serving]".'
|
135 |
+
"Or install FastAPI and uvicorn separately."
|
136 |
+
)
|
137 |
+
else:
|
138 |
+
logger.info(f"Serving model over {host}:{port}")
|
139 |
+
self._app = FastAPI(
|
140 |
+
routes=[
|
141 |
+
APIRoute(
|
142 |
+
"/",
|
143 |
+
self.model_info,
|
144 |
+
response_model=ServeModelInfoResult,
|
145 |
+
response_class=JSONResponse,
|
146 |
+
methods=["GET"],
|
147 |
+
),
|
148 |
+
APIRoute(
|
149 |
+
"/tokenize",
|
150 |
+
self.tokenize,
|
151 |
+
response_model=ServeTokenizeResult,
|
152 |
+
response_class=JSONResponse,
|
153 |
+
methods=["POST"],
|
154 |
+
),
|
155 |
+
APIRoute(
|
156 |
+
"/detokenize",
|
157 |
+
self.detokenize,
|
158 |
+
response_model=ServeDeTokenizeResult,
|
159 |
+
response_class=JSONResponse,
|
160 |
+
methods=["POST"],
|
161 |
+
),
|
162 |
+
APIRoute(
|
163 |
+
"/forward",
|
164 |
+
self.forward,
|
165 |
+
response_model=ServeForwardResult,
|
166 |
+
response_class=JSONResponse,
|
167 |
+
methods=["POST"],
|
168 |
+
),
|
169 |
+
],
|
170 |
+
timeout=600,
|
171 |
+
)
|
172 |
+
|
173 |
+
def run(self):
|
174 |
+
run(self._app, host=self.host, port=self.port, workers=self.workers)
|
175 |
+
|
176 |
+
def model_info(self):
|
177 |
+
return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
|
178 |
+
|
179 |
+
def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
|
180 |
+
"""
|
181 |
+
Tokenize the provided input and eventually returns corresponding tokens id: - **text_input**: String to
|
182 |
+
tokenize - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer
|
183 |
+
mapping.
|
184 |
+
"""
|
185 |
+
try:
|
186 |
+
tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
|
187 |
+
|
188 |
+
if return_ids:
|
189 |
+
tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
|
190 |
+
return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
|
191 |
+
else:
|
192 |
+
return ServeTokenizeResult(tokens=tokens_txt)
|
193 |
+
|
194 |
+
except Exception as e:
|
195 |
+
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
196 |
+
|
197 |
+
def detokenize(
|
198 |
+
self,
|
199 |
+
tokens_ids: List[int] = Body(None, embed=True),
|
200 |
+
skip_special_tokens: bool = Body(False, embed=True),
|
201 |
+
cleanup_tokenization_spaces: bool = Body(True, embed=True),
|
202 |
+
):
|
203 |
+
"""
|
204 |
+
Detokenize the provided tokens ids to readable text: - **tokens_ids**: List of tokens ids -
|
205 |
+
**skip_special_tokens**: Flag indicating to not try to decode special tokens - **cleanup_tokenization_spaces**:
|
206 |
+
Flag indicating to remove all leading/trailing spaces and intermediate ones.
|
207 |
+
"""
|
208 |
+
try:
|
209 |
+
decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
|
210 |
+
return ServeDeTokenizeResult(model="", text=decoded_str)
|
211 |
+
except Exception as e:
|
212 |
+
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
213 |
+
|
214 |
+
async def forward(self, inputs=Body(None, embed=True)):
|
215 |
+
"""
|
216 |
+
**inputs**: **attention_mask**: **tokens_type_ids**:
|
217 |
+
"""
|
218 |
+
|
219 |
+
# Check we don't have empty string
|
220 |
+
if len(inputs) == 0:
|
221 |
+
return ServeForwardResult(output=[], attention=[])
|
222 |
+
|
223 |
+
try:
|
224 |
+
# Forward through the model
|
225 |
+
output = self._pipeline(inputs)
|
226 |
+
return ServeForwardResult(output=output)
|
227 |
+
except Exception as e:
|
228 |
+
raise HTTPException(500, {"error": str(e)})
|
transformers_4_35_0/commands/train.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
from argparse import ArgumentParser, Namespace
|
17 |
+
|
18 |
+
from ..data import SingleSentenceClassificationProcessor as Processor
|
19 |
+
from ..pipelines import TextClassificationPipeline
|
20 |
+
from ..utils import is_tf_available, is_torch_available, logging
|
21 |
+
from . import BaseTransformersCLICommand
|
22 |
+
|
23 |
+
|
24 |
+
if not is_tf_available() and not is_torch_available():
|
25 |
+
raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
|
26 |
+
|
27 |
+
# TF training parameters
|
28 |
+
USE_XLA = False
|
29 |
+
USE_AMP = False
|
30 |
+
|
31 |
+
|
32 |
+
def train_command_factory(args: Namespace):
|
33 |
+
"""
|
34 |
+
Factory function used to instantiate training command from provided command line arguments.
|
35 |
+
|
36 |
+
Returns: TrainCommand
|
37 |
+
"""
|
38 |
+
return TrainCommand(args)
|
39 |
+
|
40 |
+
|
41 |
+
class TrainCommand(BaseTransformersCLICommand):
|
42 |
+
@staticmethod
|
43 |
+
def register_subcommand(parser: ArgumentParser):
|
44 |
+
"""
|
45 |
+
Register this command to argparse so it's available for the transformer-cli
|
46 |
+
|
47 |
+
Args:
|
48 |
+
parser: Root parser to register command-specific arguments
|
49 |
+
"""
|
50 |
+
train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")
|
51 |
+
|
52 |
+
train_parser.add_argument(
|
53 |
+
"--train_data",
|
54 |
+
type=str,
|
55 |
+
required=True,
|
56 |
+
help="path to train (and optionally evaluation) dataset as a csv with tab separated labels and sentences.",
|
57 |
+
)
|
58 |
+
train_parser.add_argument(
|
59 |
+
"--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
|
60 |
+
)
|
61 |
+
train_parser.add_argument(
|
62 |
+
"--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
|
63 |
+
)
|
64 |
+
train_parser.add_argument(
|
65 |
+
"--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
|
66 |
+
)
|
67 |
+
train_parser.add_argument(
|
68 |
+
"--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
|
69 |
+
)
|
70 |
+
|
71 |
+
train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
|
72 |
+
train_parser.add_argument(
|
73 |
+
"--validation_split",
|
74 |
+
type=float,
|
75 |
+
default=0.1,
|
76 |
+
help="if validation dataset is not provided, fraction of train dataset to use as validation dataset.",
|
77 |
+
)
|
78 |
+
|
79 |
+
train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")
|
80 |
+
|
81 |
+
train_parser.add_argument(
|
82 |
+
"--task", type=str, default="text_classification", help="Task to train the model on."
|
83 |
+
)
|
84 |
+
train_parser.add_argument(
|
85 |
+
"--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model."
|
86 |
+
)
|
87 |
+
train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
|
88 |
+
train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
|
89 |
+
train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
|
90 |
+
train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
|
91 |
+
train_parser.set_defaults(func=train_command_factory)
|
92 |
+
|
93 |
+
def __init__(self, args: Namespace):
|
94 |
+
self.logger = logging.get_logger("transformers-cli/training")
|
95 |
+
|
96 |
+
self.framework = "tf" if is_tf_available() else "torch"
|
97 |
+
|
98 |
+
os.makedirs(args.output, exist_ok=True)
|
99 |
+
self.output = args.output
|
100 |
+
|
101 |
+
self.column_label = args.column_label
|
102 |
+
self.column_text = args.column_text
|
103 |
+
self.column_id = args.column_id
|
104 |
+
|
105 |
+
self.logger.info(f"Loading {args.task} pipeline for {args.model}")
|
106 |
+
if args.task == "text_classification":
|
107 |
+
self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
|
108 |
+
elif args.task == "token_classification":
|
109 |
+
raise NotImplementedError
|
110 |
+
elif args.task == "question_answering":
|
111 |
+
raise NotImplementedError
|
112 |
+
|
113 |
+
self.logger.info(f"Loading dataset from {args.train_data}")
|
114 |
+
self.train_dataset = Processor.create_from_csv(
|
115 |
+
args.train_data,
|
116 |
+
column_label=args.column_label,
|
117 |
+
column_text=args.column_text,
|
118 |
+
column_id=args.column_id,
|
119 |
+
skip_first_row=args.skip_first_row,
|
120 |
+
)
|
121 |
+
self.valid_dataset = None
|
122 |
+
if args.validation_data:
|
123 |
+
self.logger.info(f"Loading validation dataset from {args.validation_data}")
|
124 |
+
self.valid_dataset = Processor.create_from_csv(
|
125 |
+
args.validation_data,
|
126 |
+
column_label=args.column_label,
|
127 |
+
column_text=args.column_text,
|
128 |
+
column_id=args.column_id,
|
129 |
+
skip_first_row=args.skip_first_row,
|
130 |
+
)
|
131 |
+
|
132 |
+
self.validation_split = args.validation_split
|
133 |
+
self.train_batch_size = args.train_batch_size
|
134 |
+
self.valid_batch_size = args.valid_batch_size
|
135 |
+
self.learning_rate = args.learning_rate
|
136 |
+
self.adam_epsilon = args.adam_epsilon
|
137 |
+
|
138 |
+
def run(self):
|
139 |
+
if self.framework == "tf":
|
140 |
+
return self.run_tf()
|
141 |
+
return self.run_torch()
|
142 |
+
|
143 |
+
def run_torch(self):
|
144 |
+
raise NotImplementedError
|
145 |
+
|
146 |
+
def run_tf(self):
|
147 |
+
self.pipeline.fit(
|
148 |
+
self.train_dataset,
|
149 |
+
validation_data=self.valid_dataset,
|
150 |
+
validation_split=self.validation_split,
|
151 |
+
learning_rate=self.learning_rate,
|
152 |
+
adam_epsilon=self.adam_epsilon,
|
153 |
+
train_batch_size=self.train_batch_size,
|
154 |
+
valid_batch_size=self.valid_batch_size,
|
155 |
+
)
|
156 |
+
|
157 |
+
# Save trained pipeline
|
158 |
+
self.pipeline.save_pretrained(self.output)
|
transformers_4_35_0/commands/transformers_cli.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
|
18 |
+
from .add_new_model import AddNewModelCommand
|
19 |
+
from .add_new_model_like import AddNewModelLikeCommand
|
20 |
+
from .convert import ConvertCommand
|
21 |
+
from .download import DownloadCommand
|
22 |
+
from .env import EnvironmentCommand
|
23 |
+
from .lfs import LfsCommands
|
24 |
+
from .pt_to_tf import PTtoTFCommand
|
25 |
+
from .run import RunCommand
|
26 |
+
from .serving import ServeCommand
|
27 |
+
from .user import UserCommands
|
28 |
+
|
29 |
+
|
30 |
+
def main():
|
31 |
+
parser = ArgumentParser("Transformers CLI tool", usage="transformers-cli <command> [<args>]")
|
32 |
+
commands_parser = parser.add_subparsers(help="transformers-cli command helpers")
|
33 |
+
|
34 |
+
# Register commands
|
35 |
+
ConvertCommand.register_subcommand(commands_parser)
|
36 |
+
DownloadCommand.register_subcommand(commands_parser)
|
37 |
+
EnvironmentCommand.register_subcommand(commands_parser)
|
38 |
+
RunCommand.register_subcommand(commands_parser)
|
39 |
+
ServeCommand.register_subcommand(commands_parser)
|
40 |
+
UserCommands.register_subcommand(commands_parser)
|
41 |
+
AddNewModelCommand.register_subcommand(commands_parser)
|
42 |
+
AddNewModelLikeCommand.register_subcommand(commands_parser)
|
43 |
+
LfsCommands.register_subcommand(commands_parser)
|
44 |
+
PTtoTFCommand.register_subcommand(commands_parser)
|
45 |
+
|
46 |
+
# Let's go
|
47 |
+
args = parser.parse_args()
|
48 |
+
|
49 |
+
if not hasattr(args, "func"):
|
50 |
+
parser.print_help()
|
51 |
+
exit(1)
|
52 |
+
|
53 |
+
# Run
|
54 |
+
service = args.func(args)
|
55 |
+
service.run()
|
56 |
+
|
57 |
+
|
58 |
+
if __name__ == "__main__":
|
59 |
+
main()
|
transformers_4_35_0/commands/user.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import subprocess
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
from typing import List, Union
|
18 |
+
|
19 |
+
from huggingface_hub.hf_api import HfFolder, create_repo, whoami
|
20 |
+
from requests.exceptions import HTTPError
|
21 |
+
|
22 |
+
from . import BaseTransformersCLICommand
|
23 |
+
|
24 |
+
|
25 |
+
class UserCommands(BaseTransformersCLICommand):
|
26 |
+
@staticmethod
|
27 |
+
def register_subcommand(parser: ArgumentParser):
|
28 |
+
login_parser = parser.add_parser("login", help="Log in using the same credentials as on huggingface.co")
|
29 |
+
login_parser.set_defaults(func=lambda args: LoginCommand(args))
|
30 |
+
whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.")
|
31 |
+
whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
|
32 |
+
logout_parser = parser.add_parser("logout", help="Log out")
|
33 |
+
logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
|
34 |
+
|
35 |
+
# new system: git-based repo system
|
36 |
+
repo_parser = parser.add_parser(
|
37 |
+
"repo",
|
38 |
+
help="Deprecated: use `huggingface-cli` instead. Commands to interact with your huggingface.co repos.",
|
39 |
+
)
|
40 |
+
repo_subparsers = repo_parser.add_subparsers(
|
41 |
+
help="Deprecated: use `huggingface-cli` instead. huggingface.co repos related commands"
|
42 |
+
)
|
43 |
+
repo_create_parser = repo_subparsers.add_parser(
|
44 |
+
"create", help="Deprecated: use `huggingface-cli` instead. Create a new repo on huggingface.co"
|
45 |
+
)
|
46 |
+
repo_create_parser.add_argument(
|
47 |
+
"name",
|
48 |
+
type=str,
|
49 |
+
help="Name for your model's repo. Will be namespaced under your username to build the model id.",
|
50 |
+
)
|
51 |
+
repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
|
52 |
+
repo_create_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
|
53 |
+
repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args))
|
54 |
+
|
55 |
+
|
56 |
+
class ANSI:
|
57 |
+
"""
|
58 |
+
Helper for en.wikipedia.org/wiki/ANSI_escape_code
|
59 |
+
"""
|
60 |
+
|
61 |
+
_bold = "\u001b[1m"
|
62 |
+
_red = "\u001b[31m"
|
63 |
+
_gray = "\u001b[90m"
|
64 |
+
_reset = "\u001b[0m"
|
65 |
+
|
66 |
+
@classmethod
|
67 |
+
def bold(cls, s):
|
68 |
+
return f"{cls._bold}{s}{cls._reset}"
|
69 |
+
|
70 |
+
@classmethod
|
71 |
+
def red(cls, s):
|
72 |
+
return f"{cls._bold}{cls._red}{s}{cls._reset}"
|
73 |
+
|
74 |
+
@classmethod
|
75 |
+
def gray(cls, s):
|
76 |
+
return f"{cls._gray}{s}{cls._reset}"
|
77 |
+
|
78 |
+
|
79 |
+
def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
|
80 |
+
"""
|
81 |
+
Inspired by:
|
82 |
+
|
83 |
+
- stackoverflow.com/a/8356620/593036
|
84 |
+
- stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
|
85 |
+
"""
|
86 |
+
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
|
87 |
+
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
|
88 |
+
lines = []
|
89 |
+
lines.append(row_format.format(*headers))
|
90 |
+
lines.append(row_format.format(*["-" * w for w in col_widths]))
|
91 |
+
for row in rows:
|
92 |
+
lines.append(row_format.format(*row))
|
93 |
+
return "\n".join(lines)
|
94 |
+
|
95 |
+
|
96 |
+
class BaseUserCommand:
|
97 |
+
def __init__(self, args):
|
98 |
+
self.args = args
|
99 |
+
|
100 |
+
|
101 |
+
class LoginCommand(BaseUserCommand):
|
102 |
+
def run(self):
|
103 |
+
print(
|
104 |
+
ANSI.red(
|
105 |
+
"ERROR! `huggingface-cli login` uses an outdated login mechanism "
|
106 |
+
"that is not compatible with the Hugging Face Hub backend anymore. "
|
107 |
+
"Please use `huggingface-cli login instead."
|
108 |
+
)
|
109 |
+
)
|
110 |
+
|
111 |
+
|
112 |
+
class WhoamiCommand(BaseUserCommand):
|
113 |
+
def run(self):
|
114 |
+
print(
|
115 |
+
ANSI.red(
|
116 |
+
"WARNING! `transformers-cli whoami` is deprecated and will be removed in v5. Please use "
|
117 |
+
"`huggingface-cli whoami` instead."
|
118 |
+
)
|
119 |
+
)
|
120 |
+
token = HfFolder.get_token()
|
121 |
+
if token is None:
|
122 |
+
print("Not logged in")
|
123 |
+
exit()
|
124 |
+
try:
|
125 |
+
user, orgs = whoami(token)
|
126 |
+
print(user)
|
127 |
+
if orgs:
|
128 |
+
print(ANSI.bold("orgs: "), ",".join(orgs))
|
129 |
+
except HTTPError as e:
|
130 |
+
print(e)
|
131 |
+
print(ANSI.red(e.response.text))
|
132 |
+
exit(1)
|
133 |
+
|
134 |
+
|
135 |
+
class LogoutCommand(BaseUserCommand):
|
136 |
+
def run(self):
|
137 |
+
print(
|
138 |
+
ANSI.red(
|
139 |
+
"ERROR! `transformers-cli logout` uses an outdated logout mechanism "
|
140 |
+
"that is not compatible with the Hugging Face Hub backend anymore. "
|
141 |
+
"Please use `huggingface-cli logout instead."
|
142 |
+
)
|
143 |
+
)
|
144 |
+
|
145 |
+
|
146 |
+
class RepoCreateCommand(BaseUserCommand):
|
147 |
+
def run(self):
|
148 |
+
print(
|
149 |
+
ANSI.red(
|
150 |
+
"WARNING! Managing repositories through transformers-cli is deprecated. "
|
151 |
+
"Please use `huggingface-cli` instead."
|
152 |
+
)
|
153 |
+
)
|
154 |
+
token = HfFolder.get_token()
|
155 |
+
if token is None:
|
156 |
+
print("Not logged in")
|
157 |
+
exit(1)
|
158 |
+
try:
|
159 |
+
stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
|
160 |
+
print(ANSI.gray(stdout.strip()))
|
161 |
+
except FileNotFoundError:
|
162 |
+
print("Looks like you do not have git installed, please install.")
|
163 |
+
|
164 |
+
try:
|
165 |
+
stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
|
166 |
+
print(ANSI.gray(stdout.strip()))
|
167 |
+
except FileNotFoundError:
|
168 |
+
print(
|
169 |
+
ANSI.red(
|
170 |
+
"Looks like you do not have git-lfs installed, please install."
|
171 |
+
" You can install from https://git-lfs.github.com/."
|
172 |
+
" Then run `git lfs install` (you only have to do this once)."
|
173 |
+
)
|
174 |
+
)
|
175 |
+
print("")
|
176 |
+
|
177 |
+
user, _ = whoami(token)
|
178 |
+
namespace = self.args.organization if self.args.organization is not None else user
|
179 |
+
full_name = f"{namespace}/{self.args.name}"
|
180 |
+
print(f"You are about to create {ANSI.bold(full_name)}")
|
181 |
+
|
182 |
+
if not self.args.yes:
|
183 |
+
choice = input("Proceed? [Y/n] ").lower()
|
184 |
+
if not (choice == "" or choice == "y" or choice == "yes"):
|
185 |
+
print("Abort")
|
186 |
+
exit()
|
187 |
+
try:
|
188 |
+
url = create_repo(token, name=self.args.name, organization=self.args.organization)
|
189 |
+
except HTTPError as e:
|
190 |
+
print(e)
|
191 |
+
print(ANSI.red(e.response.text))
|
192 |
+
exit(1)
|
193 |
+
print("\nYour repo now lives at:")
|
194 |
+
print(f" {ANSI.bold(url)}")
|
195 |
+
print("\nYou can clone it locally with the command below, and commit/push as usual.")
|
196 |
+
print(f"\n git clone {url}")
|
197 |
+
print("")
|
transformers_4_35_0/configuration_utils.py
ADDED
@@ -0,0 +1,1075 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" Configuration base class and utilities."""
|
17 |
+
|
18 |
+
|
19 |
+
import copy
|
20 |
+
import json
|
21 |
+
import os
|
22 |
+
import re
|
23 |
+
import warnings
|
24 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
25 |
+
|
26 |
+
from packaging import version
|
27 |
+
|
28 |
+
from . import __version__
|
29 |
+
from .dynamic_module_utils import custom_object_save
|
30 |
+
from .utils import (
|
31 |
+
CONFIG_NAME,
|
32 |
+
PushToHubMixin,
|
33 |
+
add_model_info_to_auto_map,
|
34 |
+
cached_file,
|
35 |
+
copy_func,
|
36 |
+
download_url,
|
37 |
+
extract_commit_hash,
|
38 |
+
is_remote_url,
|
39 |
+
is_torch_available,
|
40 |
+
logging,
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
logger = logging.get_logger(__name__)
|
45 |
+
|
46 |
+
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
47 |
+
|
48 |
+
|
49 |
+
class PretrainedConfig(PushToHubMixin):
|
50 |
+
# no-format
|
51 |
+
r"""
|
52 |
+
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
|
53 |
+
methods for loading/downloading/saving configurations.
|
54 |
+
|
55 |
+
<Tip>
|
56 |
+
|
57 |
+
A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
|
58 |
+
initialize a model does **not** load the model weights. It only affects the model's configuration.
|
59 |
+
|
60 |
+
</Tip>
|
61 |
+
|
62 |
+
Class attributes (overridden by derived classes):
|
63 |
+
|
64 |
+
- **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
|
65 |
+
the correct object in [`~transformers.AutoConfig`].
|
66 |
+
- **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
|
67 |
+
config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
|
68 |
+
[`~transformers.EncoderDecoderConfig`] or [`~RagConfig`].
|
69 |
+
- **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
|
70 |
+
outputs of the model during inference.
|
71 |
+
- **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
|
72 |
+
naming of attributes.
|
73 |
+
|
74 |
+
Common attributes (present in all subclasses):
|
75 |
+
|
76 |
+
- **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
|
77 |
+
embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
|
78 |
+
- **hidden_size** (`int`) -- The hidden size of the model.
|
79 |
+
- **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the
|
80 |
+
model.
|
81 |
+
- **num_hidden_layers** (`int`) -- The number of blocks in the model.
|
82 |
+
|
83 |
+
Arg:
|
84 |
+
name_or_path (`str`, *optional*, defaults to `""`):
|
85 |
+
Store the string that was passed to [`PreTrainedModel.from_pretrained`] or
|
86 |
+
[`TFPreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path` if the configuration was created
|
87 |
+
with such a method.
|
88 |
+
output_hidden_states (`bool`, *optional*, defaults to `False`):
|
89 |
+
Whether or not the model should return all hidden-states.
|
90 |
+
output_attentions (`bool`, *optional*, defaults to `False`):
|
91 |
+
Whether or not the model should returns all attentions.
|
92 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
93 |
+
Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple.
|
94 |
+
is_encoder_decoder (`bool`, *optional*, defaults to `False`):
|
95 |
+
Whether the model is used as an encoder/decoder or not.
|
96 |
+
is_decoder (`bool`, *optional*, defaults to `False`):
|
97 |
+
Whether the model is used as decoder or not (in which case it's used as an encoder).
|
98 |
+
cross_attention_hidden_size** (`bool`, *optional*):
|
99 |
+
The hidden size of the cross-attention layer in case the model is used as a decoder in an encoder-decoder
|
100 |
+
setting and the cross-attention hidden dimension differs from `self.config.hidden_size`.
|
101 |
+
add_cross_attention (`bool`, *optional*, defaults to `False`):
|
102 |
+
Whether cross-attention layers should be added to the model. Note, this option is only relevant for models
|
103 |
+
that can be used as decoder models within the [`EncoderDecoderModel`] class, which consists of all models
|
104 |
+
in `AUTO_MODELS_FOR_CAUSAL_LM`.
|
105 |
+
tie_encoder_decoder (`bool`, *optional*, defaults to `False`):
|
106 |
+
Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder
|
107 |
+
and decoder model to have the exact same parameter names.
|
108 |
+
prune_heads (`Dict[int, List[int]]`, *optional*, defaults to `{}`):
|
109 |
+
Pruned heads of the model. The keys are the selected layer indices and the associated values, the list of
|
110 |
+
heads to prune in said layer.
|
111 |
+
|
112 |
+
For instance `{1: [0, 2], 2: [2, 3]}` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
|
113 |
+
chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
|
114 |
+
The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
|
115 |
+
the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
|
116 |
+
sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
|
117 |
+
Forward Chunking work?](../glossary.html#feed-forward-chunking).
|
118 |
+
|
119 |
+
> Parameters for sequence generation
|
120 |
+
|
121 |
+
max_length (`int`, *optional*, defaults to 20):
|
122 |
+
Maximum length that will be used by default in the `generate` method of the model.
|
123 |
+
min_length (`int`, *optional*, defaults to 0):
|
124 |
+
Minimum length that will be used by default in the `generate` method of the model.
|
125 |
+
do_sample (`bool`, *optional*, defaults to `False`):
|
126 |
+
Flag that will be used by default in the `generate` method of the model. Whether or not to use sampling ;
|
127 |
+
use greedy decoding otherwise.
|
128 |
+
early_stopping (`bool`, *optional*, defaults to `False`):
|
129 |
+
Flag that will be used by default in the `generate` method of the model. Whether to stop the beam search
|
130 |
+
when at least `num_beams` sentences are finished per batch or not.
|
131 |
+
num_beams (`int`, *optional*, defaults to 1):
|
132 |
+
Number of beams for beam search that will be used by default in the `generate` method of the model. 1 means
|
133 |
+
no beam search.
|
134 |
+
num_beam_groups (`int`, *optional*, defaults to 1):
|
135 |
+
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams
|
136 |
+
that will be used by default in the `generate` method of the model. 1 means no group beam search.
|
137 |
+
diversity_penalty (`float`, *optional*, defaults to 0.0):
|
138 |
+
Value to control diversity for group beam search. that will be used by default in the `generate` method of
|
139 |
+
the model. 0 means no diversity penalty. The higher the penalty, the more diverse are the outputs.
|
140 |
+
temperature (`float`, *optional*, defaults to 1.0):
|
141 |
+
The value used to module the next token probabilities that will be used by default in the `generate` method
|
142 |
+
of the model. Must be strictly positive.
|
143 |
+
top_k (`int`, *optional*, defaults to 50):
|
144 |
+
Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in
|
145 |
+
the `generate` method of the model.
|
146 |
+
top_p (`float`, *optional*, defaults to 1):
|
147 |
+
Value that will be used by default in the `generate` method of the model for `top_p`. If set to float < 1,
|
148 |
+
only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
|
149 |
+
typical_p (`float`, *optional*, defaults to 1):
|
150 |
+
Local typicality measures how similar the conditional probability of predicting a target token next is to
|
151 |
+
the expected conditional probability of predicting a random token next, given the partial text already
|
152 |
+
generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that
|
153 |
+
add up to `typical_p` or higher are kept for generation. See [this
|
154 |
+
paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
|
155 |
+
repetition_penalty (`float`, *optional*, defaults to 1):
|
156 |
+
Parameter for repetition penalty that will be used by default in the `generate` method of the model. 1.0
|
157 |
+
means no penalty.
|
158 |
+
length_penalty (`float`, *optional*, defaults to 1):
|
159 |
+
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
|
160 |
+
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
|
161 |
+
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
|
162 |
+
`length_penalty` < 0.0 encourages shorter sequences.
|
163 |
+
no_repeat_ngram_size (`int`, *optional*, defaults to 0) -- Value that will be used by default in the
|
164 |
+
`generate` method of the model for `no_repeat_ngram_size`. If set to int > 0, all ngrams of that size can
|
165 |
+
only occur once.
|
166 |
+
encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0) -- Value that will be used by
|
167 |
+
default in the `generate` method of the model for `encoder_no_repeat_ngram_size`. If set to int > 0, all
|
168 |
+
ngrams of that size that occur in the `encoder_input_ids` cannot occur in the `decoder_input_ids`.
|
169 |
+
bad_words_ids (`List[int]`, *optional*):
|
170 |
+
List of token ids that are not allowed to be generated that will be used by default in the `generate`
|
171 |
+
method of the model. In order to get the tokens of the words that should not appear in the generated text,
|
172 |
+
use `tokenizer.encode(bad_word, add_prefix_space=True)`.
|
173 |
+
num_return_sequences (`int`, *optional*, defaults to 1):
|
174 |
+
Number of independently computed returned sequences for each element in the batch that will be used by
|
175 |
+
default in the `generate` method of the model.
|
176 |
+
output_scores (`bool`, *optional*, defaults to `False`):
|
177 |
+
Whether the model should return the logits when used for generation.
|
178 |
+
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
179 |
+
Whether the model should return a [`~transformers.utils.ModelOutput`] instead of a `torch.LongTensor`.
|
180 |
+
forced_bos_token_id (`int`, *optional*):
|
181 |
+
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful for
|
182 |
+
multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be the target
|
183 |
+
language token.
|
184 |
+
forced_eos_token_id (`int`, *optional*):
|
185 |
+
The id of the token to force as the last generated token when `max_length` is reached.
|
186 |
+
remove_invalid_values (`bool`, *optional*):
|
187 |
+
Whether to remove possible _nan_ and _inf_ outputs of the model to prevent the generation method to crash.
|
188 |
+
Note that using `remove_invalid_values` can slow down generation.
|
189 |
+
|
190 |
+
> Parameters for fine-tuning tasks
|
191 |
+
|
192 |
+
architectures (`List[str]`, *optional*):
|
193 |
+
Model architectures that can be used with the model pretrained weights.
|
194 |
+
finetuning_task (`str`, *optional*):
|
195 |
+
Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow
|
196 |
+
or PyTorch) checkpoint.
|
197 |
+
id2label (`Dict[int, str]`, *optional*):
|
198 |
+
A map from index (for instance prediction index, or target index) to label.
|
199 |
+
label2id (`Dict[str, int]`, *optional*): A map from label to index for the model.
|
200 |
+
num_labels (`int`, *optional*):
|
201 |
+
Number of labels to use in the last layer added to the model, typically for a classification task.
|
202 |
+
task_specific_params (`Dict[str, Any]`, *optional*):
|
203 |
+
Additional keyword arguments to store for the current task.
|
204 |
+
problem_type (`str`, *optional*):
|
205 |
+
Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
|
206 |
+
`"single_label_classification"` or `"multi_label_classification"`.
|
207 |
+
|
208 |
+
> Parameters linked to the tokenizer
|
209 |
+
|
210 |
+
tokenizer_class (`str`, *optional*):
|
211 |
+
The name of the associated tokenizer class to use (if none is set, will use the tokenizer associated to the
|
212 |
+
model by default).
|
213 |
+
prefix (`str`, *optional*):
|
214 |
+
A specific prompt that should be added at the beginning of each text before calling the model.
|
215 |
+
bos_token_id (`int`, *optional*): The id of the _beginning-of-stream_ token.
|
216 |
+
pad_token_id (`int`, *optional*): The id of the _padding_ token.
|
217 |
+
eos_token_id (`int`, *optional*): The id of the _end-of-stream_ token.
|
218 |
+
decoder_start_token_id (`int`, *optional*):
|
219 |
+
If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token.
|
220 |
+
sep_token_id (`int`, *optional*): The id of the _separation_ token.
|
221 |
+
|
222 |
+
> PyTorch specific parameters
|
223 |
+
|
224 |
+
torchscript (`bool`, *optional*, defaults to `False`):
|
225 |
+
Whether or not the model should be used with Torchscript.
|
226 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
|
227 |
+
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
|
228 |
+
model has a output word embedding layer.
|
229 |
+
torch_dtype (`str`, *optional*):
|
230 |
+
The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
|
231 |
+
(which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
|
232 |
+
model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
|
233 |
+
`float16` weights. Since the config object is stored in plain text, this attribute contains just the
|
234 |
+
floating type string without the `torch.` prefix. For example, for `torch.float16` ``torch_dtype` is the
|
235 |
+
`"float16"` string.
|
236 |
+
|
237 |
+
This attribute is currently not being used during model loading time, but this may change in the future
|
238 |
+
versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
|
239 |
+
|
240 |
+
> TensorFlow specific parameters
|
241 |
+
|
242 |
+
use_bfloat16 (`bool`, *optional*, defaults to `False`):
|
243 |
+
Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
|
244 |
+
tf_legacy_loss (`bool`, *optional*, defaults to `False`):
|
245 |
+
Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
|
246 |
+
not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
|
247 |
+
v5.
|
248 |
+
"""
|
249 |
+
model_type: str = ""
|
250 |
+
is_composition: bool = False
|
251 |
+
attribute_map: Dict[str, str] = {}
|
252 |
+
_auto_class: Optional[str] = None
|
253 |
+
|
254 |
+
def __setattr__(self, key, value):
|
255 |
+
if key in super().__getattribute__("attribute_map"):
|
256 |
+
key = super().__getattribute__("attribute_map")[key]
|
257 |
+
super().__setattr__(key, value)
|
258 |
+
|
259 |
+
def __getattribute__(self, key):
|
260 |
+
if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
|
261 |
+
key = super().__getattribute__("attribute_map")[key]
|
262 |
+
return super().__getattribute__(key)
|
263 |
+
|
264 |
+
def __init__(self, **kwargs):
|
265 |
+
# Attributes with defaults
|
266 |
+
self.return_dict = kwargs.pop("return_dict", True)
|
267 |
+
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
268 |
+
self.output_attentions = kwargs.pop("output_attentions", False)
|
269 |
+
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
270 |
+
self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
|
271 |
+
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
272 |
+
self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False) # Only used by TensorFlow models
|
273 |
+
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
274 |
+
self.tie_word_embeddings = kwargs.pop(
|
275 |
+
"tie_word_embeddings", True
|
276 |
+
) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
|
277 |
+
|
278 |
+
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
|
279 |
+
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
|
280 |
+
self.is_decoder = kwargs.pop("is_decoder", False)
|
281 |
+
self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
|
282 |
+
self.add_cross_attention = kwargs.pop("add_cross_attention", False)
|
283 |
+
self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
|
284 |
+
|
285 |
+
# Parameters for sequence generation
|
286 |
+
self.max_length = kwargs.pop("max_length", 20)
|
287 |
+
self.min_length = kwargs.pop("min_length", 0)
|
288 |
+
self.do_sample = kwargs.pop("do_sample", False)
|
289 |
+
self.early_stopping = kwargs.pop("early_stopping", False)
|
290 |
+
self.num_beams = kwargs.pop("num_beams", 1)
|
291 |
+
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
|
292 |
+
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
|
293 |
+
self.temperature = kwargs.pop("temperature", 1.0)
|
294 |
+
self.top_k = kwargs.pop("top_k", 50)
|
295 |
+
self.top_p = kwargs.pop("top_p", 1.0)
|
296 |
+
self.typical_p = kwargs.pop("typical_p", 1.0)
|
297 |
+
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
298 |
+
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
299 |
+
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
300 |
+
self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
|
301 |
+
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
302 |
+
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
303 |
+
self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
|
304 |
+
self.output_scores = kwargs.pop("output_scores", False)
|
305 |
+
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
|
306 |
+
self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
|
307 |
+
self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
|
308 |
+
self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
|
309 |
+
self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
|
310 |
+
self.suppress_tokens = kwargs.pop("suppress_tokens", None)
|
311 |
+
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
|
312 |
+
|
313 |
+
# Fine-tuning task arguments
|
314 |
+
self.architectures = kwargs.pop("architectures", None)
|
315 |
+
self.finetuning_task = kwargs.pop("finetuning_task", None)
|
316 |
+
self.id2label = kwargs.pop("id2label", None)
|
317 |
+
self.label2id = kwargs.pop("label2id", None)
|
318 |
+
if self.label2id is not None and not isinstance(self.label2id, dict):
|
319 |
+
raise ValueError("Argument label2id should be a dictionary.")
|
320 |
+
if self.id2label is not None:
|
321 |
+
if not isinstance(self.id2label, dict):
|
322 |
+
raise ValueError("Argument id2label should be a dictionary.")
|
323 |
+
num_labels = kwargs.pop("num_labels", None)
|
324 |
+
if num_labels is not None and len(self.id2label) != num_labels:
|
325 |
+
logger.warning(
|
326 |
+
f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
|
327 |
+
f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}."
|
328 |
+
)
|
329 |
+
self.id2label = {int(key): value for key, value in self.id2label.items()}
|
330 |
+
# Keys are always strings in JSON so convert ids to int here.
|
331 |
+
else:
|
332 |
+
self.num_labels = kwargs.pop("num_labels", 2)
|
333 |
+
|
334 |
+
if self.torch_dtype is not None and isinstance(self.torch_dtype, str):
|
335 |
+
# we will start using self.torch_dtype in v5, but to be consistent with
|
336 |
+
# from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
|
337 |
+
if is_torch_available():
|
338 |
+
import torch
|
339 |
+
|
340 |
+
self.torch_dtype = getattr(torch, self.torch_dtype)
|
341 |
+
|
342 |
+
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
|
343 |
+
self.tokenizer_class = kwargs.pop("tokenizer_class", None)
|
344 |
+
self.prefix = kwargs.pop("prefix", None)
|
345 |
+
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
346 |
+
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
347 |
+
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
348 |
+
self.sep_token_id = kwargs.pop("sep_token_id", None)
|
349 |
+
|
350 |
+
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
|
351 |
+
|
352 |
+
# task specific arguments
|
353 |
+
self.task_specific_params = kwargs.pop("task_specific_params", None)
|
354 |
+
|
355 |
+
# regression / multi-label classification
|
356 |
+
self.problem_type = kwargs.pop("problem_type", None)
|
357 |
+
allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
|
358 |
+
if self.problem_type is not None and self.problem_type not in allowed_problem_types:
|
359 |
+
raise ValueError(
|
360 |
+
f"The config parameter `problem_type` was not understood: received {self.problem_type} "
|
361 |
+
"but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
|
362 |
+
)
|
363 |
+
|
364 |
+
# TPU arguments
|
365 |
+
if kwargs.pop("xla_device", None) is not None:
|
366 |
+
logger.warning(
|
367 |
+
"The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
|
368 |
+
"safely remove it from your `config.json` file."
|
369 |
+
)
|
370 |
+
|
371 |
+
# Name or path to the pretrained checkpoint
|
372 |
+
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
373 |
+
# Config hash
|
374 |
+
self._commit_hash = kwargs.pop("_commit_hash", None)
|
375 |
+
|
376 |
+
# Drop the transformers version info
|
377 |
+
self.transformers_version = kwargs.pop("transformers_version", None)
|
378 |
+
|
379 |
+
# Deal with gradient checkpointing
|
380 |
+
if kwargs.get("gradient_checkpointing", False):
|
381 |
+
warnings.warn(
|
382 |
+
"Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
|
383 |
+
"Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
|
384 |
+
"`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
|
385 |
+
)
|
386 |
+
|
387 |
+
# Additional attributes without default values
|
388 |
+
for key, value in kwargs.items():
|
389 |
+
try:
|
390 |
+
setattr(self, key, value)
|
391 |
+
except AttributeError as err:
|
392 |
+
logger.error(f"Can't set {key} with value {value} for {self}")
|
393 |
+
raise err
|
394 |
+
|
395 |
+
@property
|
396 |
+
def name_or_path(self) -> str:
|
397 |
+
return getattr(self, "_name_or_path", None)
|
398 |
+
|
399 |
+
@name_or_path.setter
|
400 |
+
def name_or_path(self, value):
|
401 |
+
self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
|
402 |
+
|
403 |
+
@property
|
404 |
+
def use_return_dict(self) -> bool:
|
405 |
+
"""
|
406 |
+
`bool`: Whether or not return [`~utils.ModelOutput`] instead of tuples.
|
407 |
+
"""
|
408 |
+
# If torchscript is set, force `return_dict=False` to avoid jit errors
|
409 |
+
return self.return_dict and not self.torchscript
|
410 |
+
|
411 |
+
@property
|
412 |
+
def num_labels(self) -> int:
|
413 |
+
"""
|
414 |
+
`int`: The number of labels for classification models.
|
415 |
+
"""
|
416 |
+
return len(self.id2label)
|
417 |
+
|
418 |
+
@num_labels.setter
|
419 |
+
def num_labels(self, num_labels: int):
|
420 |
+
if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
|
421 |
+
self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
|
422 |
+
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
|
423 |
+
|
424 |
+
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
425 |
+
"""
|
426 |
+
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
427 |
+
[`~PretrainedConfig.from_pretrained`] class method.
|
428 |
+
|
429 |
+
Args:
|
430 |
+
save_directory (`str` or `os.PathLike`):
|
431 |
+
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
432 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
433 |
+
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
434 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
435 |
+
namespace).
|
436 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
437 |
+
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
438 |
+
"""
|
439 |
+
self._set_token_in_kwargs(kwargs)
|
440 |
+
|
441 |
+
if os.path.isfile(save_directory):
|
442 |
+
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
443 |
+
|
444 |
+
os.makedirs(save_directory, exist_ok=True)
|
445 |
+
|
446 |
+
if push_to_hub:
|
447 |
+
commit_message = kwargs.pop("commit_message", None)
|
448 |
+
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
449 |
+
repo_id = self._create_repo(repo_id, **kwargs)
|
450 |
+
files_timestamps = self._get_files_timestamps(save_directory)
|
451 |
+
|
452 |
+
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
|
453 |
+
# loaded from the Hub.
|
454 |
+
if self._auto_class is not None:
|
455 |
+
custom_object_save(self, save_directory, config=self)
|
456 |
+
|
457 |
+
# If we save using the predefined names, we can load using `from_pretrained`
|
458 |
+
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
459 |
+
|
460 |
+
self.to_json_file(output_config_file, use_diff=True)
|
461 |
+
logger.info(f"Configuration saved in {output_config_file}")
|
462 |
+
|
463 |
+
if push_to_hub:
|
464 |
+
self._upload_modified_files(
|
465 |
+
save_directory,
|
466 |
+
repo_id,
|
467 |
+
files_timestamps,
|
468 |
+
commit_message=commit_message,
|
469 |
+
token=kwargs.get("token"),
|
470 |
+
)
|
471 |
+
|
472 |
+
@staticmethod
|
473 |
+
def _set_token_in_kwargs(kwargs, token=None):
|
474 |
+
"""Temporary method to deal with `token` and `use_auth_token`.
|
475 |
+
|
476 |
+
This method is to avoid apply the same changes in all model config classes that overwrite `from_pretrained`.
|
477 |
+
|
478 |
+
Need to clean up `use_auth_token` in a follow PR.
|
479 |
+
"""
|
480 |
+
# Some model config classes like CLIP define their own `from_pretrained` without the new argument `token` yet.
|
481 |
+
if token is None:
|
482 |
+
token = kwargs.pop("token", None)
|
483 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
484 |
+
|
485 |
+
if use_auth_token is not None:
|
486 |
+
warnings.warn(
|
487 |
+
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
488 |
+
)
|
489 |
+
if token is not None:
|
490 |
+
raise ValueError(
|
491 |
+
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
|
492 |
+
)
|
493 |
+
token = use_auth_token
|
494 |
+
|
495 |
+
if token is not None:
|
496 |
+
kwargs["token"] = token
|
497 |
+
|
498 |
+
@classmethod
|
499 |
+
def from_pretrained(
|
500 |
+
cls,
|
501 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
502 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
503 |
+
force_download: bool = False,
|
504 |
+
local_files_only: bool = False,
|
505 |
+
token: Optional[Union[str, bool]] = None,
|
506 |
+
revision: str = "main",
|
507 |
+
**kwargs,
|
508 |
+
) -> "PretrainedConfig":
|
509 |
+
r"""
|
510 |
+
Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.
|
511 |
+
|
512 |
+
Args:
|
513 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
514 |
+
This can be either:
|
515 |
+
|
516 |
+
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
517 |
+
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
|
518 |
+
namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
|
519 |
+
- a path to a *directory* containing a configuration file saved using the
|
520 |
+
[`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
|
521 |
+
- a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
|
522 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
523 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
524 |
+
standard cache should not be used.
|
525 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
526 |
+
Whether or not to force to (re-)download the configuration files and override the cached versions if
|
527 |
+
they exist.
|
528 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
529 |
+
Whether or not to delete incompletely received file. Attempts to resume the download if such a file
|
530 |
+
exists.
|
531 |
+
proxies (`Dict[str, str]`, *optional*):
|
532 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
533 |
+
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
534 |
+
token (`str` or `bool`, *optional*):
|
535 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
|
536 |
+
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
|
537 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
538 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
539 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
540 |
+
identifier allowed by git.
|
541 |
+
|
542 |
+
<Tip>
|
543 |
+
|
544 |
+
To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>".
|
545 |
+
|
546 |
+
</Tip>
|
547 |
+
|
548 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
549 |
+
If `False`, then this function returns just the final configuration object.
|
550 |
+
|
551 |
+
If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
|
552 |
+
dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
|
553 |
+
part of `kwargs` which has not been used to update `config` and is otherwise ignored.
|
554 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
555 |
+
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
|
556 |
+
specify the folder name here.
|
557 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
558 |
+
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
|
559 |
+
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
|
560 |
+
by the `return_unused_kwargs` keyword parameter.
|
561 |
+
|
562 |
+
Returns:
|
563 |
+
[`PretrainedConfig`]: The configuration object instantiated from this pretrained model.
|
564 |
+
|
565 |
+
Examples:
|
566 |
+
|
567 |
+
```python
|
568 |
+
# We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
|
569 |
+
# derived class: BertConfig
|
570 |
+
config = BertConfig.from_pretrained(
|
571 |
+
"bert-base-uncased"
|
572 |
+
) # Download configuration from huggingface.co and cache.
|
573 |
+
config = BertConfig.from_pretrained(
|
574 |
+
"./test/saved_model/"
|
575 |
+
) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
|
576 |
+
config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
|
577 |
+
config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
|
578 |
+
assert config.output_attentions == True
|
579 |
+
config, unused_kwargs = BertConfig.from_pretrained(
|
580 |
+
"bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
|
581 |
+
)
|
582 |
+
assert config.output_attentions == True
|
583 |
+
assert unused_kwargs == {"foo": False}
|
584 |
+
```"""
|
585 |
+
kwargs["cache_dir"] = cache_dir
|
586 |
+
kwargs["force_download"] = force_download
|
587 |
+
kwargs["local_files_only"] = local_files_only
|
588 |
+
kwargs["revision"] = revision
|
589 |
+
|
590 |
+
cls._set_token_in_kwargs(kwargs, token)
|
591 |
+
|
592 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
593 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
594 |
+
logger.warning(
|
595 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
596 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
597 |
+
)
|
598 |
+
|
599 |
+
return cls.from_dict(config_dict, **kwargs)
|
600 |
+
|
601 |
+
@classmethod
|
602 |
+
def get_config_dict(
|
603 |
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
604 |
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
605 |
+
"""
|
606 |
+
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
|
607 |
+
[`PretrainedConfig`] using `from_dict`.
|
608 |
+
|
609 |
+
Parameters:
|
610 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
611 |
+
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
|
612 |
+
|
613 |
+
Returns:
|
614 |
+
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
|
615 |
+
|
616 |
+
"""
|
617 |
+
cls._set_token_in_kwargs(kwargs)
|
618 |
+
|
619 |
+
original_kwargs = copy.deepcopy(kwargs)
|
620 |
+
# Get config dict associated with the base config file
|
621 |
+
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
|
622 |
+
if "_commit_hash" in config_dict:
|
623 |
+
original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
|
624 |
+
|
625 |
+
# That config file may point us toward another config file to use.
|
626 |
+
if "configuration_files" in config_dict:
|
627 |
+
configuration_file = get_configuration_file(config_dict["configuration_files"])
|
628 |
+
config_dict, kwargs = cls._get_config_dict(
|
629 |
+
pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
|
630 |
+
)
|
631 |
+
|
632 |
+
return config_dict, kwargs
|
633 |
+
|
634 |
+
@classmethod
|
635 |
+
def _get_config_dict(
|
636 |
+
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
637 |
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
638 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
639 |
+
force_download = kwargs.pop("force_download", False)
|
640 |
+
resume_download = kwargs.pop("resume_download", False)
|
641 |
+
proxies = kwargs.pop("proxies", None)
|
642 |
+
token = kwargs.pop("token", None)
|
643 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
644 |
+
revision = kwargs.pop("revision", None)
|
645 |
+
trust_remote_code = kwargs.pop("trust_remote_code", None)
|
646 |
+
subfolder = kwargs.pop("subfolder", "")
|
647 |
+
from_pipeline = kwargs.pop("_from_pipeline", None)
|
648 |
+
from_auto_class = kwargs.pop("_from_auto", False)
|
649 |
+
commit_hash = kwargs.pop("_commit_hash", None)
|
650 |
+
|
651 |
+
if trust_remote_code is True:
|
652 |
+
logger.warning(
|
653 |
+
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
|
654 |
+
" ignored."
|
655 |
+
)
|
656 |
+
|
657 |
+
user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
|
658 |
+
if from_pipeline is not None:
|
659 |
+
user_agent["using_pipeline"] = from_pipeline
|
660 |
+
|
661 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
662 |
+
|
663 |
+
is_local = os.path.isdir(pretrained_model_name_or_path)
|
664 |
+
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
|
665 |
+
# Special case when pretrained_model_name_or_path is a local file
|
666 |
+
resolved_config_file = pretrained_model_name_or_path
|
667 |
+
is_local = True
|
668 |
+
elif is_remote_url(pretrained_model_name_or_path):
|
669 |
+
configuration_file = pretrained_model_name_or_path
|
670 |
+
resolved_config_file = download_url(pretrained_model_name_or_path)
|
671 |
+
else:
|
672 |
+
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
|
673 |
+
|
674 |
+
try:
|
675 |
+
# Load from local folder or from cache or download from model Hub and cache
|
676 |
+
resolved_config_file = cached_file(
|
677 |
+
pretrained_model_name_or_path,
|
678 |
+
configuration_file,
|
679 |
+
cache_dir=cache_dir,
|
680 |
+
force_download=force_download,
|
681 |
+
proxies=proxies,
|
682 |
+
resume_download=resume_download,
|
683 |
+
local_files_only=local_files_only,
|
684 |
+
token=token,
|
685 |
+
user_agent=user_agent,
|
686 |
+
revision=revision,
|
687 |
+
subfolder=subfolder,
|
688 |
+
_commit_hash=commit_hash,
|
689 |
+
)
|
690 |
+
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
|
691 |
+
except EnvironmentError:
|
692 |
+
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
|
693 |
+
# the original exception.
|
694 |
+
raise
|
695 |
+
except Exception:
|
696 |
+
# For any other exception, we throw a generic error.
|
697 |
+
raise EnvironmentError(
|
698 |
+
f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
|
699 |
+
" from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
|
700 |
+
f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
|
701 |
+
f" containing a {configuration_file} file"
|
702 |
+
)
|
703 |
+
|
704 |
+
try:
|
705 |
+
# Load config dict
|
706 |
+
config_dict = cls._dict_from_json_file(resolved_config_file)
|
707 |
+
config_dict["_commit_hash"] = commit_hash
|
708 |
+
except (json.JSONDecodeError, UnicodeDecodeError):
|
709 |
+
raise EnvironmentError(
|
710 |
+
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
711 |
+
)
|
712 |
+
|
713 |
+
if is_local:
|
714 |
+
logger.info(f"loading configuration file {resolved_config_file}")
|
715 |
+
else:
|
716 |
+
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
|
717 |
+
|
718 |
+
if "auto_map" in config_dict and not is_local:
|
719 |
+
config_dict["auto_map"] = add_model_info_to_auto_map(
|
720 |
+
config_dict["auto_map"], pretrained_model_name_or_path
|
721 |
+
)
|
722 |
+
return config_dict, kwargs
|
723 |
+
|
724 |
+
@classmethod
|
725 |
+
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
|
726 |
+
"""
|
727 |
+
Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
|
728 |
+
|
729 |
+
Args:
|
730 |
+
config_dict (`Dict[str, Any]`):
|
731 |
+
Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
|
732 |
+
retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
|
733 |
+
kwargs (`Dict[str, Any]`):
|
734 |
+
Additional parameters from which to initialize the configuration object.
|
735 |
+
|
736 |
+
Returns:
|
737 |
+
[`PretrainedConfig`]: The configuration object instantiated from those parameters.
|
738 |
+
"""
|
739 |
+
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
740 |
+
# Those arguments may be passed along for our internal telemetry.
|
741 |
+
# We remove them so they don't appear in `return_unused_kwargs`.
|
742 |
+
kwargs.pop("_from_auto", None)
|
743 |
+
kwargs.pop("_from_pipeline", None)
|
744 |
+
# The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
|
745 |
+
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
|
746 |
+
kwargs["_commit_hash"] = config_dict["_commit_hash"]
|
747 |
+
|
748 |
+
config = cls(**config_dict)
|
749 |
+
|
750 |
+
if hasattr(config, "pruned_heads"):
|
751 |
+
config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
|
752 |
+
|
753 |
+
# Update config with kwargs if needed
|
754 |
+
if "num_labels" in kwargs and "id2label" in kwargs:
|
755 |
+
num_labels = kwargs["num_labels"]
|
756 |
+
id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
|
757 |
+
if len(id2label) != num_labels:
|
758 |
+
raise ValueError(
|
759 |
+
f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
|
760 |
+
f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
|
761 |
+
"one of them."
|
762 |
+
)
|
763 |
+
to_remove = []
|
764 |
+
for key, value in kwargs.items():
|
765 |
+
if hasattr(config, key):
|
766 |
+
current_attr = getattr(config, key)
|
767 |
+
# To authorize passing a custom subconfig as kwarg in models that have nested configs.
|
768 |
+
if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
|
769 |
+
value = current_attr.__class__(**value)
|
770 |
+
setattr(config, key, value)
|
771 |
+
if key != "torch_dtype":
|
772 |
+
to_remove.append(key)
|
773 |
+
for key in to_remove:
|
774 |
+
kwargs.pop(key, None)
|
775 |
+
|
776 |
+
logger.info(f"Model config {config}")
|
777 |
+
if return_unused_kwargs:
|
778 |
+
return config, kwargs
|
779 |
+
else:
|
780 |
+
return config
|
781 |
+
|
782 |
+
@classmethod
|
783 |
+
def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
|
784 |
+
"""
|
785 |
+
Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.
|
786 |
+
|
787 |
+
Args:
|
788 |
+
json_file (`str` or `os.PathLike`):
|
789 |
+
Path to the JSON file containing the parameters.
|
790 |
+
|
791 |
+
Returns:
|
792 |
+
[`PretrainedConfig`]: The configuration object instantiated from that JSON file.
|
793 |
+
|
794 |
+
"""
|
795 |
+
config_dict = cls._dict_from_json_file(json_file)
|
796 |
+
return cls(**config_dict)
|
797 |
+
|
798 |
+
@classmethod
|
799 |
+
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
800 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
801 |
+
text = reader.read()
|
802 |
+
return json.loads(text)
|
803 |
+
|
804 |
+
def __eq__(self, other):
|
805 |
+
return isinstance(other, PretrainedConfig) and (self.__dict__ == other.__dict__)
|
806 |
+
|
807 |
+
def __repr__(self):
|
808 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
809 |
+
|
810 |
+
def to_diff_dict(self) -> Dict[str, Any]:
|
811 |
+
"""
|
812 |
+
Removes all attributes from config which correspond to the default config attributes for better readability and
|
813 |
+
serializes to a Python dictionary.
|
814 |
+
|
815 |
+
Returns:
|
816 |
+
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
|
817 |
+
"""
|
818 |
+
config_dict = self.to_dict()
|
819 |
+
|
820 |
+
# get the default config dict
|
821 |
+
default_config_dict = PretrainedConfig().to_dict()
|
822 |
+
|
823 |
+
# get class specific config dict
|
824 |
+
class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
|
825 |
+
|
826 |
+
serializable_config_dict = {}
|
827 |
+
|
828 |
+
# only serialize values that differ from the default config
|
829 |
+
for key, value in config_dict.items():
|
830 |
+
if (
|
831 |
+
isinstance(getattr(self, key, None), PretrainedConfig)
|
832 |
+
and key in class_config_dict
|
833 |
+
and isinstance(class_config_dict[key], dict)
|
834 |
+
):
|
835 |
+
# For nested configs we need to clean the diff recursively
|
836 |
+
diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
|
837 |
+
if "model_type" in value:
|
838 |
+
# Needs to be set even if it's not in the diff
|
839 |
+
diff["model_type"] = value["model_type"]
|
840 |
+
if len(diff) > 0:
|
841 |
+
serializable_config_dict[key] = diff
|
842 |
+
elif (
|
843 |
+
key not in default_config_dict
|
844 |
+
or key == "transformers_version"
|
845 |
+
or value != default_config_dict[key]
|
846 |
+
or (key in class_config_dict and value != class_config_dict[key])
|
847 |
+
):
|
848 |
+
serializable_config_dict[key] = value
|
849 |
+
|
850 |
+
if hasattr(self, "quantization_config"):
|
851 |
+
serializable_config_dict["quantization_config"] = (
|
852 |
+
self.quantization_config.to_dict()
|
853 |
+
if not isinstance(self.quantization_config, dict)
|
854 |
+
else self.quantization_config
|
855 |
+
)
|
856 |
+
|
857 |
+
self.dict_torch_dtype_to_str(serializable_config_dict)
|
858 |
+
|
859 |
+
if "_flash_attn_2_enabled" in serializable_config_dict:
|
860 |
+
del serializable_config_dict["_flash_attn_2_enabled"]
|
861 |
+
|
862 |
+
return serializable_config_dict
|
863 |
+
|
864 |
+
def to_dict(self) -> Dict[str, Any]:
|
865 |
+
"""
|
866 |
+
Serializes this instance to a Python dictionary.
|
867 |
+
|
868 |
+
Returns:
|
869 |
+
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
870 |
+
"""
|
871 |
+
output = copy.deepcopy(self.__dict__)
|
872 |
+
if hasattr(self.__class__, "model_type"):
|
873 |
+
output["model_type"] = self.__class__.model_type
|
874 |
+
if "_auto_class" in output:
|
875 |
+
del output["_auto_class"]
|
876 |
+
if "_commit_hash" in output:
|
877 |
+
del output["_commit_hash"]
|
878 |
+
if "_flash_attn_2_enabled" in output:
|
879 |
+
del output["_flash_attn_2_enabled"]
|
880 |
+
|
881 |
+
# Transformers version when serializing the model
|
882 |
+
output["transformers_version"] = __version__
|
883 |
+
|
884 |
+
for key, value in output.items():
|
885 |
+
# Deal with nested configs like CLIP
|
886 |
+
if isinstance(value, PretrainedConfig):
|
887 |
+
value = value.to_dict()
|
888 |
+
del value["transformers_version"]
|
889 |
+
|
890 |
+
output[key] = value
|
891 |
+
|
892 |
+
if hasattr(self, "quantization_config"):
|
893 |
+
output["quantization_config"] = (
|
894 |
+
self.quantization_config.to_dict()
|
895 |
+
if not isinstance(self.quantization_config, dict)
|
896 |
+
else self.quantization_config
|
897 |
+
)
|
898 |
+
|
899 |
+
self.dict_torch_dtype_to_str(output)
|
900 |
+
|
901 |
+
return output
|
902 |
+
|
903 |
+
def to_json_string(self, use_diff: bool = True) -> str:
|
904 |
+
"""
|
905 |
+
Serializes this instance to a JSON string.
|
906 |
+
|
907 |
+
Args:
|
908 |
+
use_diff (`bool`, *optional*, defaults to `True`):
|
909 |
+
If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
|
910 |
+
is serialized to JSON string.
|
911 |
+
|
912 |
+
Returns:
|
913 |
+
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
914 |
+
"""
|
915 |
+
if use_diff is True:
|
916 |
+
config_dict = self.to_diff_dict()
|
917 |
+
else:
|
918 |
+
config_dict = self.to_dict()
|
919 |
+
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
920 |
+
|
921 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
|
922 |
+
"""
|
923 |
+
Save this instance to a JSON file.
|
924 |
+
|
925 |
+
Args:
|
926 |
+
json_file_path (`str` or `os.PathLike`):
|
927 |
+
Path to the JSON file in which this configuration instance's parameters will be saved.
|
928 |
+
use_diff (`bool`, *optional*, defaults to `True`):
|
929 |
+
If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
|
930 |
+
is serialized to JSON file.
|
931 |
+
"""
|
932 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
933 |
+
writer.write(self.to_json_string(use_diff=use_diff))
|
934 |
+
|
935 |
+
def update(self, config_dict: Dict[str, Any]):
|
936 |
+
"""
|
937 |
+
Updates attributes of this class with attributes from `config_dict`.
|
938 |
+
|
939 |
+
Args:
|
940 |
+
config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
|
941 |
+
"""
|
942 |
+
for key, value in config_dict.items():
|
943 |
+
setattr(self, key, value)
|
944 |
+
|
945 |
+
def update_from_string(self, update_str: str):
|
946 |
+
"""
|
947 |
+
Updates attributes of this class with attributes from `update_str`.
|
948 |
+
|
949 |
+
The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
|
950 |
+
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
951 |
+
|
952 |
+
The keys to change have to already exist in the config object.
|
953 |
+
|
954 |
+
Args:
|
955 |
+
update_str (`str`): String with attributes that should be updated for this class.
|
956 |
+
|
957 |
+
"""
|
958 |
+
|
959 |
+
d = dict(x.split("=") for x in update_str.split(","))
|
960 |
+
for k, v in d.items():
|
961 |
+
if not hasattr(self, k):
|
962 |
+
raise ValueError(f"key {k} isn't in the original config dict")
|
963 |
+
|
964 |
+
old_v = getattr(self, k)
|
965 |
+
if isinstance(old_v, bool):
|
966 |
+
if v.lower() in ["true", "1", "y", "yes"]:
|
967 |
+
v = True
|
968 |
+
elif v.lower() in ["false", "0", "n", "no"]:
|
969 |
+
v = False
|
970 |
+
else:
|
971 |
+
raise ValueError(f"can't derive true or false from {v} (key {k})")
|
972 |
+
elif isinstance(old_v, int):
|
973 |
+
v = int(v)
|
974 |
+
elif isinstance(old_v, float):
|
975 |
+
v = float(v)
|
976 |
+
elif not isinstance(old_v, str):
|
977 |
+
raise ValueError(
|
978 |
+
f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
|
979 |
+
)
|
980 |
+
|
981 |
+
setattr(self, k, v)
|
982 |
+
|
983 |
+
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
|
984 |
+
"""
|
985 |
+
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
|
986 |
+
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
|
987 |
+
string, which can then be stored in the json format.
|
988 |
+
"""
|
989 |
+
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
|
990 |
+
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
|
991 |
+
for value in d.values():
|
992 |
+
if isinstance(value, dict):
|
993 |
+
self.dict_torch_dtype_to_str(value)
|
994 |
+
|
995 |
+
@classmethod
|
996 |
+
def register_for_auto_class(cls, auto_class="AutoConfig"):
|
997 |
+
"""
|
998 |
+
Register this class with a given auto class. This should only be used for custom configurations as the ones in
|
999 |
+
the library are already mapped with `AutoConfig`.
|
1000 |
+
|
1001 |
+
<Tip warning={true}>
|
1002 |
+
|
1003 |
+
This API is experimental and may have some slight breaking changes in the next releases.
|
1004 |
+
|
1005 |
+
</Tip>
|
1006 |
+
|
1007 |
+
Args:
|
1008 |
+
auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
|
1009 |
+
The auto class to register this new configuration with.
|
1010 |
+
"""
|
1011 |
+
if not isinstance(auto_class, str):
|
1012 |
+
auto_class = auto_class.__name__
|
1013 |
+
|
1014 |
+
import transformers.models.auto as auto_module
|
1015 |
+
|
1016 |
+
if not hasattr(auto_module, auto_class):
|
1017 |
+
raise ValueError(f"{auto_class} is not a valid auto class.")
|
1018 |
+
|
1019 |
+
cls._auto_class = auto_class
|
1020 |
+
|
1021 |
+
|
1022 |
+
def get_configuration_file(configuration_files: List[str]) -> str:
|
1023 |
+
"""
|
1024 |
+
Get the configuration file to use for this version of transformers.
|
1025 |
+
|
1026 |
+
Args:
|
1027 |
+
configuration_files (`List[str]`): The list of available configuration files.
|
1028 |
+
|
1029 |
+
Returns:
|
1030 |
+
`str`: The configuration file to use.
|
1031 |
+
"""
|
1032 |
+
configuration_files_map = {}
|
1033 |
+
for file_name in configuration_files:
|
1034 |
+
search = _re_configuration_file.search(file_name)
|
1035 |
+
if search is not None:
|
1036 |
+
v = search.groups()[0]
|
1037 |
+
configuration_files_map[v] = file_name
|
1038 |
+
available_versions = sorted(configuration_files_map.keys())
|
1039 |
+
|
1040 |
+
# Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
|
1041 |
+
configuration_file = CONFIG_NAME
|
1042 |
+
transformers_version = version.parse(__version__)
|
1043 |
+
for v in available_versions:
|
1044 |
+
if version.parse(v) <= transformers_version:
|
1045 |
+
configuration_file = configuration_files_map[v]
|
1046 |
+
else:
|
1047 |
+
# No point going further since the versions are sorted.
|
1048 |
+
break
|
1049 |
+
|
1050 |
+
return configuration_file
|
1051 |
+
|
1052 |
+
|
1053 |
+
def recursive_diff_dict(dict_a, dict_b, config_obj=None):
|
1054 |
+
"""
|
1055 |
+
Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
|
1056 |
+
values from `dict_a` that are different from values in `dict_b`.
|
1057 |
+
"""
|
1058 |
+
diff = {}
|
1059 |
+
default = config_obj.__class__().to_dict() if config_obj is not None else {}
|
1060 |
+
for key, value in dict_a.items():
|
1061 |
+
obj_value = getattr(config_obj, str(key), None)
|
1062 |
+
if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
|
1063 |
+
diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
|
1064 |
+
if len(diff_value) > 0:
|
1065 |
+
diff[key] = diff_value
|
1066 |
+
elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]:
|
1067 |
+
diff[key] = value
|
1068 |
+
return diff
|
1069 |
+
|
1070 |
+
|
1071 |
+
PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
|
1072 |
+
if PretrainedConfig.push_to_hub.__doc__ is not None:
|
1073 |
+
PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
|
1074 |
+
object="config", object_class="AutoConfig", object_files="configuration file"
|
1075 |
+
)
|
transformers_4_35_0/convert_graph_to_onnx.py
ADDED
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import warnings
|
16 |
+
from argparse import ArgumentParser
|
17 |
+
from os import listdir, makedirs
|
18 |
+
from pathlib import Path
|
19 |
+
from typing import Dict, List, Optional, Tuple
|
20 |
+
|
21 |
+
from packaging.version import Version, parse
|
22 |
+
|
23 |
+
from transformers.pipelines import Pipeline, pipeline
|
24 |
+
from transformers.tokenization_utils import BatchEncoding
|
25 |
+
from transformers.utils import ModelOutput, is_tf_available, is_torch_available
|
26 |
+
|
27 |
+
|
28 |
+
# This is the minimal required version to
|
29 |
+
# support some ONNX Runtime features
|
30 |
+
ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")
|
31 |
+
|
32 |
+
|
33 |
+
SUPPORTED_PIPELINES = [
|
34 |
+
"feature-extraction",
|
35 |
+
"ner",
|
36 |
+
"sentiment-analysis",
|
37 |
+
"fill-mask",
|
38 |
+
"question-answering",
|
39 |
+
"text-generation",
|
40 |
+
"translation_en_to_fr",
|
41 |
+
"translation_en_to_de",
|
42 |
+
"translation_en_to_ro",
|
43 |
+
]
|
44 |
+
|
45 |
+
|
46 |
+
class OnnxConverterArgumentParser(ArgumentParser):
|
47 |
+
"""
|
48 |
+
Wraps all the script arguments supported to export transformers models to ONNX IR
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(self):
|
52 |
+
super().__init__("ONNX Converter")
|
53 |
+
|
54 |
+
self.add_argument(
|
55 |
+
"--pipeline",
|
56 |
+
type=str,
|
57 |
+
choices=SUPPORTED_PIPELINES,
|
58 |
+
default="feature-extraction",
|
59 |
+
)
|
60 |
+
self.add_argument(
|
61 |
+
"--model",
|
62 |
+
type=str,
|
63 |
+
required=True,
|
64 |
+
help="Model's id or path (ex: bert-base-cased)",
|
65 |
+
)
|
66 |
+
self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: bert-base-cased)")
|
67 |
+
self.add_argument(
|
68 |
+
"--framework",
|
69 |
+
type=str,
|
70 |
+
choices=["pt", "tf"],
|
71 |
+
help="Framework for loading the model",
|
72 |
+
)
|
73 |
+
self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
|
74 |
+
self.add_argument(
|
75 |
+
"--check-loading",
|
76 |
+
action="store_true",
|
77 |
+
help="Check ONNX is able to load the model",
|
78 |
+
)
|
79 |
+
self.add_argument(
|
80 |
+
"--use-external-format",
|
81 |
+
action="store_true",
|
82 |
+
help="Allow exporting model >= than 2Gb",
|
83 |
+
)
|
84 |
+
self.add_argument(
|
85 |
+
"--quantize",
|
86 |
+
action="store_true",
|
87 |
+
help="Quantize the neural network to be run with int8",
|
88 |
+
)
|
89 |
+
self.add_argument("output")
|
90 |
+
|
91 |
+
|
92 |
+
def generate_identified_filename(filename: Path, identifier: str) -> Path:
|
93 |
+
"""
|
94 |
+
Append a string-identifier at the end (before the extension, if any) to the provided filepath
|
95 |
+
|
96 |
+
Args:
|
97 |
+
filename: pathlib.Path The actual path object we would like to add an identifier suffix
|
98 |
+
identifier: The suffix to add
|
99 |
+
|
100 |
+
Returns: String with concatenated identifier at the end of the filename
|
101 |
+
"""
|
102 |
+
return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix)
|
103 |
+
|
104 |
+
|
105 |
+
def check_onnxruntime_requirements(minimum_version: Version):
|
106 |
+
"""
|
107 |
+
Check onnxruntime is installed and if the installed version match is recent enough
|
108 |
+
|
109 |
+
Raises:
|
110 |
+
ImportError: If onnxruntime is not installed or too old version is found
|
111 |
+
"""
|
112 |
+
try:
|
113 |
+
import onnxruntime
|
114 |
+
|
115 |
+
# Parse the version of the installed onnxruntime
|
116 |
+
ort_version = parse(onnxruntime.__version__)
|
117 |
+
|
118 |
+
# We require 1.4.0 minimum
|
119 |
+
if ort_version < ORT_QUANTIZE_MINIMUM_VERSION:
|
120 |
+
raise ImportError(
|
121 |
+
f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
|
122 |
+
f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
|
123 |
+
"Please update onnxruntime by running `pip install --upgrade onnxruntime`"
|
124 |
+
)
|
125 |
+
|
126 |
+
except ImportError:
|
127 |
+
raise ImportError(
|
128 |
+
"onnxruntime doesn't seem to be currently installed. "
|
129 |
+
"Please install the onnxruntime by running `pip install onnxruntime`"
|
130 |
+
" and relaunch the conversion."
|
131 |
+
)
|
132 |
+
|
133 |
+
|
134 |
+
def ensure_valid_input(model, tokens, input_names):
|
135 |
+
"""
|
136 |
+
Ensure inputs are presented in the correct order, without any Non
|
137 |
+
|
138 |
+
Args:
|
139 |
+
model: The model used to forward the input data
|
140 |
+
tokens: BatchEncoding holding the input data
|
141 |
+
input_names: The name of the inputs
|
142 |
+
|
143 |
+
Returns: Tuple
|
144 |
+
|
145 |
+
"""
|
146 |
+
print("Ensuring inputs are in correct order")
|
147 |
+
|
148 |
+
model_args_name = model.forward.__code__.co_varnames
|
149 |
+
model_args, ordered_input_names = [], []
|
150 |
+
for arg_name in model_args_name[1:]: # start at index 1 to skip "self" argument
|
151 |
+
if arg_name in input_names:
|
152 |
+
ordered_input_names.append(arg_name)
|
153 |
+
model_args.append(tokens[arg_name])
|
154 |
+
else:
|
155 |
+
print(f"{arg_name} is not present in the generated input list.")
|
156 |
+
break
|
157 |
+
|
158 |
+
print(f"Generated inputs order: {ordered_input_names}")
|
159 |
+
return ordered_input_names, tuple(model_args)
|
160 |
+
|
161 |
+
|
162 |
+
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
|
163 |
+
"""
|
164 |
+
Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model
|
165 |
+
|
166 |
+
Args:
|
167 |
+
nlp: The pipeline object holding the model to be exported
|
168 |
+
framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
|
172 |
+
- List of the inferred input variable names
|
173 |
+
- List of the inferred output variable names
|
174 |
+
- Dictionary with input/output variables names as key and shape tensor as value
|
175 |
+
- a BatchEncoding reference which was used to infer all the above information
|
176 |
+
"""
|
177 |
+
|
178 |
+
def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
|
179 |
+
if isinstance(tensor, (tuple, list)):
|
180 |
+
return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]
|
181 |
+
|
182 |
+
else:
|
183 |
+
# Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)
|
184 |
+
axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
|
185 |
+
if is_input:
|
186 |
+
if len(tensor.shape) == 2:
|
187 |
+
axes[1] = "sequence"
|
188 |
+
else:
|
189 |
+
raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})")
|
190 |
+
else:
|
191 |
+
seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
|
192 |
+
axes.update({dim: "sequence" for dim in seq_axes})
|
193 |
+
|
194 |
+
print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}")
|
195 |
+
return axes
|
196 |
+
|
197 |
+
tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
|
198 |
+
seq_len = tokens.input_ids.shape[-1]
|
199 |
+
outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)
|
200 |
+
if isinstance(outputs, ModelOutput):
|
201 |
+
outputs = outputs.to_tuple()
|
202 |
+
if not isinstance(outputs, (list, tuple)):
|
203 |
+
outputs = (outputs,)
|
204 |
+
|
205 |
+
# Generate input names & axes
|
206 |
+
input_vars = list(tokens.keys())
|
207 |
+
input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}
|
208 |
+
|
209 |
+
# flatten potentially grouped outputs (past for gpt2, attentions)
|
210 |
+
outputs_flat = []
|
211 |
+
for output in outputs:
|
212 |
+
if isinstance(output, (tuple, list)):
|
213 |
+
outputs_flat.extend(output)
|
214 |
+
else:
|
215 |
+
outputs_flat.append(output)
|
216 |
+
|
217 |
+
# Generate output names & axes
|
218 |
+
output_names = [f"output_{i}" for i in range(len(outputs_flat))]
|
219 |
+
output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}
|
220 |
+
|
221 |
+
# Create the aggregated axes representation
|
222 |
+
dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
|
223 |
+
return input_vars, output_names, dynamic_axes, tokens
|
224 |
+
|
225 |
+
|
226 |
+
def load_graph_from_args(
|
227 |
+
pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs
|
228 |
+
) -> Pipeline:
|
229 |
+
"""
|
230 |
+
Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model
|
231 |
+
|
232 |
+
Args:
|
233 |
+
pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)
|
234 |
+
framework: The actual model to convert the pipeline from ("pt" or "tf")
|
235 |
+
model: The model name which will be loaded by the pipeline
|
236 |
+
tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value
|
237 |
+
|
238 |
+
Returns: Pipeline object
|
239 |
+
|
240 |
+
"""
|
241 |
+
# If no tokenizer provided
|
242 |
+
if tokenizer is None:
|
243 |
+
tokenizer = model
|
244 |
+
|
245 |
+
# Check the wanted framework is available
|
246 |
+
if framework == "pt" and not is_torch_available():
|
247 |
+
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
|
248 |
+
if framework == "tf" and not is_tf_available():
|
249 |
+
raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
|
250 |
+
|
251 |
+
print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})")
|
252 |
+
|
253 |
+
# Allocate tokenizer and model
|
254 |
+
return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)
|
255 |
+
|
256 |
+
|
257 |
+
def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):
|
258 |
+
"""
|
259 |
+
Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR
|
260 |
+
|
261 |
+
Args:
|
262 |
+
nlp: The pipeline to be exported
|
263 |
+
opset: The actual version of the ONNX operator set to use
|
264 |
+
output: Path where will be stored the generated ONNX model
|
265 |
+
use_external_format: Split the model definition from its parameters to allow model bigger than 2GB
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
|
269 |
+
"""
|
270 |
+
if not is_torch_available():
|
271 |
+
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
|
272 |
+
|
273 |
+
import torch
|
274 |
+
from torch.onnx import export
|
275 |
+
|
276 |
+
from transformers.pytorch_utils import is_torch_less_than_1_11
|
277 |
+
|
278 |
+
print(f"Using framework PyTorch: {torch.__version__}")
|
279 |
+
|
280 |
+
with torch.no_grad():
|
281 |
+
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
|
282 |
+
ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
|
283 |
+
|
284 |
+
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
|
285 |
+
# so we check the torch version for backwards compatibility
|
286 |
+
if is_torch_less_than_1_11:
|
287 |
+
export(
|
288 |
+
nlp.model,
|
289 |
+
model_args,
|
290 |
+
f=output.as_posix(),
|
291 |
+
input_names=ordered_input_names,
|
292 |
+
output_names=output_names,
|
293 |
+
dynamic_axes=dynamic_axes,
|
294 |
+
do_constant_folding=True,
|
295 |
+
use_external_data_format=use_external_format,
|
296 |
+
enable_onnx_checker=True,
|
297 |
+
opset_version=opset,
|
298 |
+
)
|
299 |
+
else:
|
300 |
+
export(
|
301 |
+
nlp.model,
|
302 |
+
model_args,
|
303 |
+
f=output.as_posix(),
|
304 |
+
input_names=ordered_input_names,
|
305 |
+
output_names=output_names,
|
306 |
+
dynamic_axes=dynamic_axes,
|
307 |
+
do_constant_folding=True,
|
308 |
+
opset_version=opset,
|
309 |
+
)
|
310 |
+
|
311 |
+
|
312 |
+
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
|
313 |
+
"""
|
314 |
+
Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)
|
315 |
+
|
316 |
+
Args:
|
317 |
+
nlp: The pipeline to be exported
|
318 |
+
opset: The actual version of the ONNX operator set to use
|
319 |
+
output: Path where will be stored the generated ONNX model
|
320 |
+
|
321 |
+
Notes: TensorFlow cannot export model bigger than 2GB due to internal constraint from TensorFlow
|
322 |
+
|
323 |
+
"""
|
324 |
+
if not is_tf_available():
|
325 |
+
raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
|
326 |
+
|
327 |
+
print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")
|
328 |
+
|
329 |
+
try:
|
330 |
+
import tensorflow as tf
|
331 |
+
import tf2onnx
|
332 |
+
from tf2onnx import __version__ as t2ov
|
333 |
+
|
334 |
+
print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")
|
335 |
+
|
336 |
+
# Build
|
337 |
+
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
|
338 |
+
|
339 |
+
# Forward
|
340 |
+
nlp.model.predict(tokens.data)
|
341 |
+
input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in tokens.items()]
|
342 |
+
model_proto, _ = tf2onnx.convert.from_keras(
|
343 |
+
nlp.model, input_signature, opset=opset, output_path=output.as_posix()
|
344 |
+
)
|
345 |
+
|
346 |
+
except ImportError as e:
|
347 |
+
raise Exception(
|
348 |
+
f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first. {e}"
|
349 |
+
)
|
350 |
+
|
351 |
+
|
352 |
+
def convert(
|
353 |
+
framework: str,
|
354 |
+
model: str,
|
355 |
+
output: Path,
|
356 |
+
opset: int,
|
357 |
+
tokenizer: Optional[str] = None,
|
358 |
+
use_external_format: bool = False,
|
359 |
+
pipeline_name: str = "feature-extraction",
|
360 |
+
**model_kwargs,
|
361 |
+
):
|
362 |
+
"""
|
363 |
+
Convert the pipeline object to the ONNX Intermediate Representation (IR) format
|
364 |
+
|
365 |
+
Args:
|
366 |
+
framework: The framework the pipeline is backed by ("pt" or "tf")
|
367 |
+
model: The name of the model to load for the pipeline
|
368 |
+
output: The path where the ONNX graph will be stored
|
369 |
+
opset: The actual version of the ONNX operator set to use
|
370 |
+
tokenizer: The name of the model to load for the pipeline, default to the model's name if not provided
|
371 |
+
use_external_format:
|
372 |
+
Split the model definition from its parameters to allow model bigger than 2GB (PyTorch only)
|
373 |
+
pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)
|
374 |
+
model_kwargs: Keyword arguments to be forwarded to the model constructor
|
375 |
+
|
376 |
+
Returns:
|
377 |
+
|
378 |
+
"""
|
379 |
+
warnings.warn(
|
380 |
+
"The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of"
|
381 |
+
" Transformers",
|
382 |
+
FutureWarning,
|
383 |
+
)
|
384 |
+
print(f"ONNX opset version set to: {opset}")
|
385 |
+
|
386 |
+
# Load the pipeline
|
387 |
+
nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)
|
388 |
+
|
389 |
+
if not output.parent.exists():
|
390 |
+
print(f"Creating folder {output.parent}")
|
391 |
+
makedirs(output.parent.as_posix())
|
392 |
+
elif len(listdir(output.parent.as_posix())) > 0:
|
393 |
+
raise Exception(f"Folder {output.parent.as_posix()} is not empty, aborting conversion")
|
394 |
+
|
395 |
+
# Export the graph
|
396 |
+
if framework == "pt":
|
397 |
+
convert_pytorch(nlp, opset, output, use_external_format)
|
398 |
+
else:
|
399 |
+
convert_tensorflow(nlp, opset, output)
|
400 |
+
|
401 |
+
|
402 |
+
def optimize(onnx_model_path: Path) -> Path:
|
403 |
+
"""
|
404 |
+
Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the
|
405 |
+
optimizations possible
|
406 |
+
|
407 |
+
Args:
|
408 |
+
onnx_model_path: filepath where the model binary description is stored
|
409 |
+
|
410 |
+
Returns: Path where the optimized model binary description has been saved
|
411 |
+
|
412 |
+
"""
|
413 |
+
from onnxruntime import InferenceSession, SessionOptions
|
414 |
+
|
415 |
+
# Generate model name with suffix "optimized"
|
416 |
+
opt_model_path = generate_identified_filename(onnx_model_path, "-optimized")
|
417 |
+
sess_option = SessionOptions()
|
418 |
+
sess_option.optimized_model_filepath = opt_model_path.as_posix()
|
419 |
+
_ = InferenceSession(onnx_model_path.as_posix(), sess_option)
|
420 |
+
|
421 |
+
print(f"Optimized model has been written at {opt_model_path}: \N{heavy check mark}")
|
422 |
+
print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\")
|
423 |
+
|
424 |
+
return opt_model_path
|
425 |
+
|
426 |
+
|
427 |
+
def quantize(onnx_model_path: Path) -> Path:
|
428 |
+
"""
|
429 |
+
Quantize the weights of the model from float32 to in8 to allow very efficient inference on modern CPU
|
430 |
+
|
431 |
+
Args:
|
432 |
+
onnx_model_path: Path to location the exported ONNX model is stored
|
433 |
+
|
434 |
+
Returns: The Path generated for the quantized
|
435 |
+
"""
|
436 |
+
import onnx
|
437 |
+
import onnxruntime
|
438 |
+
from onnx.onnx_pb import ModelProto
|
439 |
+
from onnxruntime.quantization import QuantizationMode
|
440 |
+
from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
|
441 |
+
from onnxruntime.quantization.registry import IntegerOpsRegistry
|
442 |
+
|
443 |
+
# Load the ONNX model
|
444 |
+
onnx_model = onnx.load(onnx_model_path.as_posix())
|
445 |
+
|
446 |
+
if parse(onnx.__version__) < parse("1.5.0"):
|
447 |
+
print(
|
448 |
+
"Models larger than 2GB will fail to quantize due to protobuf constraint.\n"
|
449 |
+
"Please upgrade to onnxruntime >= 1.5.0."
|
450 |
+
)
|
451 |
+
|
452 |
+
# Copy it
|
453 |
+
copy_model = ModelProto()
|
454 |
+
copy_model.CopyFrom(onnx_model)
|
455 |
+
|
456 |
+
# Construct quantizer
|
457 |
+
# onnxruntime renamed input_qType to activation_qType in v1.13.1, so we
|
458 |
+
# check the onnxruntime version to ensure backward compatibility.
|
459 |
+
# See also: https://github.com/microsoft/onnxruntime/pull/12873
|
460 |
+
if parse(onnxruntime.__version__) < parse("1.13.1"):
|
461 |
+
quantizer = ONNXQuantizer(
|
462 |
+
model=copy_model,
|
463 |
+
per_channel=False,
|
464 |
+
reduce_range=False,
|
465 |
+
mode=QuantizationMode.IntegerOps,
|
466 |
+
static=False,
|
467 |
+
weight_qType=True,
|
468 |
+
input_qType=False,
|
469 |
+
tensors_range=None,
|
470 |
+
nodes_to_quantize=None,
|
471 |
+
nodes_to_exclude=None,
|
472 |
+
op_types_to_quantize=list(IntegerOpsRegistry),
|
473 |
+
)
|
474 |
+
else:
|
475 |
+
quantizer = ONNXQuantizer(
|
476 |
+
model=copy_model,
|
477 |
+
per_channel=False,
|
478 |
+
reduce_range=False,
|
479 |
+
mode=QuantizationMode.IntegerOps,
|
480 |
+
static=False,
|
481 |
+
weight_qType=True,
|
482 |
+
activation_qType=False,
|
483 |
+
tensors_range=None,
|
484 |
+
nodes_to_quantize=None,
|
485 |
+
nodes_to_exclude=None,
|
486 |
+
op_types_to_quantize=list(IntegerOpsRegistry),
|
487 |
+
)
|
488 |
+
|
489 |
+
# Quantize and export
|
490 |
+
quantizer.quantize_model()
|
491 |
+
|
492 |
+
# Append "-quantized" at the end of the model's name
|
493 |
+
quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized")
|
494 |
+
|
495 |
+
# Save model
|
496 |
+
print(f"Quantized model has been written at {quantized_model_path}: \N{heavy check mark}")
|
497 |
+
onnx.save_model(quantizer.model.model, quantized_model_path.as_posix())
|
498 |
+
|
499 |
+
return quantized_model_path
|
500 |
+
|
501 |
+
|
502 |
+
def verify(path: Path):
|
503 |
+
from onnxruntime import InferenceSession, SessionOptions
|
504 |
+
from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException
|
505 |
+
|
506 |
+
print(f"Checking ONNX model loading from: {path} ...")
|
507 |
+
try:
|
508 |
+
onnx_options = SessionOptions()
|
509 |
+
_ = InferenceSession(path.as_posix(), onnx_options, providers=["CPUExecutionProvider"])
|
510 |
+
print(f"Model {path} correctly loaded: \N{heavy check mark}")
|
511 |
+
except RuntimeException as re:
|
512 |
+
print(f"Error while loading the model {re}: \N{heavy ballot x}")
|
513 |
+
|
514 |
+
|
515 |
+
if __name__ == "__main__":
|
516 |
+
parser = OnnxConverterArgumentParser()
|
517 |
+
args = parser.parse_args()
|
518 |
+
|
519 |
+
# Make sure output is absolute path
|
520 |
+
args.output = Path(args.output).absolute()
|
521 |
+
|
522 |
+
try:
|
523 |
+
print("\n====== Converting model to ONNX ======")
|
524 |
+
# Convert
|
525 |
+
convert(
|
526 |
+
args.framework,
|
527 |
+
args.model,
|
528 |
+
args.output,
|
529 |
+
args.opset,
|
530 |
+
args.tokenizer,
|
531 |
+
args.use_external_format,
|
532 |
+
args.pipeline,
|
533 |
+
)
|
534 |
+
|
535 |
+
if args.quantize:
|
536 |
+
# Ensure requirements for quantization on onnxruntime is met
|
537 |
+
check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)
|
538 |
+
|
539 |
+
# onnxruntime optimizations doesn't provide the same level of performances on TensorFlow than PyTorch
|
540 |
+
if args.framework == "tf":
|
541 |
+
print(
|
542 |
+
"\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n"
|
543 |
+
"\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n"
|
544 |
+
"\t For more information, please refer to the onnxruntime documentation:\n"
|
545 |
+
"\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n"
|
546 |
+
)
|
547 |
+
|
548 |
+
print("\n====== Optimizing ONNX model ======")
|
549 |
+
|
550 |
+
# Quantization works best when using the optimized version of the model
|
551 |
+
args.optimized_output = optimize(args.output)
|
552 |
+
|
553 |
+
# Do the quantization on the right graph
|
554 |
+
args.quantized_output = quantize(args.optimized_output)
|
555 |
+
|
556 |
+
# And verify
|
557 |
+
if args.check_loading:
|
558 |
+
print("\n====== Check exported ONNX model(s) ======")
|
559 |
+
verify(args.output)
|
560 |
+
|
561 |
+
if hasattr(args, "optimized_output"):
|
562 |
+
verify(args.optimized_output)
|
563 |
+
|
564 |
+
if hasattr(args, "quantized_output"):
|
565 |
+
verify(args.quantized_output)
|
566 |
+
|
567 |
+
except Exception as e:
|
568 |
+
print(f"Error while converting the model: {e}")
|
569 |
+
exit(1)
|
transformers_4_35_0/convert_pytorch_checkpoint_to_tf2.py
ADDED
@@ -0,0 +1,492 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Convert pytorch checkpoints to TensorFlow"""
|
16 |
+
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
import os
|
20 |
+
|
21 |
+
from . import (
|
22 |
+
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
23 |
+
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
24 |
+
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
25 |
+
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
26 |
+
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
27 |
+
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
28 |
+
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
29 |
+
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
30 |
+
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
31 |
+
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
32 |
+
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
33 |
+
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
34 |
+
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
|
35 |
+
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
36 |
+
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
37 |
+
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
38 |
+
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
39 |
+
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
40 |
+
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
41 |
+
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
42 |
+
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
43 |
+
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
44 |
+
AlbertConfig,
|
45 |
+
BartConfig,
|
46 |
+
BertConfig,
|
47 |
+
CamembertConfig,
|
48 |
+
CTRLConfig,
|
49 |
+
DistilBertConfig,
|
50 |
+
DPRConfig,
|
51 |
+
ElectraConfig,
|
52 |
+
FlaubertConfig,
|
53 |
+
GPT2Config,
|
54 |
+
LayoutLMConfig,
|
55 |
+
LxmertConfig,
|
56 |
+
OpenAIGPTConfig,
|
57 |
+
RobertaConfig,
|
58 |
+
T5Config,
|
59 |
+
TFAlbertForPreTraining,
|
60 |
+
TFBartForConditionalGeneration,
|
61 |
+
TFBartForSequenceClassification,
|
62 |
+
TFBertForPreTraining,
|
63 |
+
TFBertForQuestionAnswering,
|
64 |
+
TFBertForSequenceClassification,
|
65 |
+
TFCamembertForMaskedLM,
|
66 |
+
TFCTRLLMHeadModel,
|
67 |
+
TFDistilBertForMaskedLM,
|
68 |
+
TFDistilBertForQuestionAnswering,
|
69 |
+
TFDPRContextEncoder,
|
70 |
+
TFDPRQuestionEncoder,
|
71 |
+
TFDPRReader,
|
72 |
+
TFElectraForPreTraining,
|
73 |
+
TFFlaubertWithLMHeadModel,
|
74 |
+
TFGPT2LMHeadModel,
|
75 |
+
TFLayoutLMForMaskedLM,
|
76 |
+
TFLxmertForPreTraining,
|
77 |
+
TFLxmertVisualFeatureEncoder,
|
78 |
+
TFOpenAIGPTLMHeadModel,
|
79 |
+
TFRobertaForCausalLM,
|
80 |
+
TFRobertaForMaskedLM,
|
81 |
+
TFRobertaForSequenceClassification,
|
82 |
+
TFT5ForConditionalGeneration,
|
83 |
+
TFTransfoXLLMHeadModel,
|
84 |
+
TFWav2Vec2Model,
|
85 |
+
TFXLMRobertaForMaskedLM,
|
86 |
+
TFXLMWithLMHeadModel,
|
87 |
+
TFXLNetLMHeadModel,
|
88 |
+
TransfoXLConfig,
|
89 |
+
Wav2Vec2Config,
|
90 |
+
Wav2Vec2Model,
|
91 |
+
XLMConfig,
|
92 |
+
XLMRobertaConfig,
|
93 |
+
XLNetConfig,
|
94 |
+
is_torch_available,
|
95 |
+
load_pytorch_checkpoint_in_tf2_model,
|
96 |
+
)
|
97 |
+
from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
|
98 |
+
|
99 |
+
|
100 |
+
if is_torch_available():
|
101 |
+
import numpy as np
|
102 |
+
import torch
|
103 |
+
|
104 |
+
from . import (
|
105 |
+
AlbertForPreTraining,
|
106 |
+
BartForConditionalGeneration,
|
107 |
+
BertForPreTraining,
|
108 |
+
BertForQuestionAnswering,
|
109 |
+
BertForSequenceClassification,
|
110 |
+
CamembertForMaskedLM,
|
111 |
+
CTRLLMHeadModel,
|
112 |
+
DistilBertForMaskedLM,
|
113 |
+
DistilBertForQuestionAnswering,
|
114 |
+
DPRContextEncoder,
|
115 |
+
DPRQuestionEncoder,
|
116 |
+
DPRReader,
|
117 |
+
ElectraForPreTraining,
|
118 |
+
FlaubertWithLMHeadModel,
|
119 |
+
GPT2LMHeadModel,
|
120 |
+
LayoutLMForMaskedLM,
|
121 |
+
LxmertForPreTraining,
|
122 |
+
LxmertVisualFeatureEncoder,
|
123 |
+
OpenAIGPTLMHeadModel,
|
124 |
+
RobertaForMaskedLM,
|
125 |
+
RobertaForSequenceClassification,
|
126 |
+
T5ForConditionalGeneration,
|
127 |
+
TransfoXLLMHeadModel,
|
128 |
+
XLMRobertaForMaskedLM,
|
129 |
+
XLMWithLMHeadModel,
|
130 |
+
XLNetLMHeadModel,
|
131 |
+
)
|
132 |
+
|
133 |
+
|
134 |
+
logging.set_verbosity_info()
|
135 |
+
|
136 |
+
MODEL_CLASSES = {
|
137 |
+
"bart": (
|
138 |
+
BartConfig,
|
139 |
+
TFBartForConditionalGeneration,
|
140 |
+
TFBartForSequenceClassification,
|
141 |
+
BartForConditionalGeneration,
|
142 |
+
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
143 |
+
),
|
144 |
+
"bert": (
|
145 |
+
BertConfig,
|
146 |
+
TFBertForPreTraining,
|
147 |
+
BertForPreTraining,
|
148 |
+
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
149 |
+
),
|
150 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad": (
|
151 |
+
BertConfig,
|
152 |
+
TFBertForQuestionAnswering,
|
153 |
+
BertForQuestionAnswering,
|
154 |
+
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
155 |
+
),
|
156 |
+
"bert-large-cased-whole-word-masking-finetuned-squad": (
|
157 |
+
BertConfig,
|
158 |
+
TFBertForQuestionAnswering,
|
159 |
+
BertForQuestionAnswering,
|
160 |
+
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
161 |
+
),
|
162 |
+
"bert-base-cased-finetuned-mrpc": (
|
163 |
+
BertConfig,
|
164 |
+
TFBertForSequenceClassification,
|
165 |
+
BertForSequenceClassification,
|
166 |
+
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
167 |
+
),
|
168 |
+
"dpr": (
|
169 |
+
DPRConfig,
|
170 |
+
TFDPRQuestionEncoder,
|
171 |
+
TFDPRContextEncoder,
|
172 |
+
TFDPRReader,
|
173 |
+
DPRQuestionEncoder,
|
174 |
+
DPRContextEncoder,
|
175 |
+
DPRReader,
|
176 |
+
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
177 |
+
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
178 |
+
DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
179 |
+
),
|
180 |
+
"gpt2": (
|
181 |
+
GPT2Config,
|
182 |
+
TFGPT2LMHeadModel,
|
183 |
+
GPT2LMHeadModel,
|
184 |
+
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
185 |
+
),
|
186 |
+
"xlnet": (
|
187 |
+
XLNetConfig,
|
188 |
+
TFXLNetLMHeadModel,
|
189 |
+
XLNetLMHeadModel,
|
190 |
+
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
191 |
+
),
|
192 |
+
"xlm": (
|
193 |
+
XLMConfig,
|
194 |
+
TFXLMWithLMHeadModel,
|
195 |
+
XLMWithLMHeadModel,
|
196 |
+
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
197 |
+
),
|
198 |
+
"xlm-roberta": (
|
199 |
+
XLMRobertaConfig,
|
200 |
+
TFXLMRobertaForMaskedLM,
|
201 |
+
XLMRobertaForMaskedLM,
|
202 |
+
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
203 |
+
),
|
204 |
+
"transfo-xl": (
|
205 |
+
TransfoXLConfig,
|
206 |
+
TFTransfoXLLMHeadModel,
|
207 |
+
TransfoXLLMHeadModel,
|
208 |
+
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
209 |
+
),
|
210 |
+
"openai-gpt": (
|
211 |
+
OpenAIGPTConfig,
|
212 |
+
TFOpenAIGPTLMHeadModel,
|
213 |
+
OpenAIGPTLMHeadModel,
|
214 |
+
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
215 |
+
),
|
216 |
+
"roberta": (
|
217 |
+
RobertaConfig,
|
218 |
+
TFRobertaForCausalLM,
|
219 |
+
TFRobertaForMaskedLM,
|
220 |
+
RobertaForMaskedLM,
|
221 |
+
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
222 |
+
),
|
223 |
+
"layoutlm": (
|
224 |
+
LayoutLMConfig,
|
225 |
+
TFLayoutLMForMaskedLM,
|
226 |
+
LayoutLMForMaskedLM,
|
227 |
+
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
|
228 |
+
),
|
229 |
+
"roberta-large-mnli": (
|
230 |
+
RobertaConfig,
|
231 |
+
TFRobertaForSequenceClassification,
|
232 |
+
RobertaForSequenceClassification,
|
233 |
+
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
234 |
+
),
|
235 |
+
"camembert": (
|
236 |
+
CamembertConfig,
|
237 |
+
TFCamembertForMaskedLM,
|
238 |
+
CamembertForMaskedLM,
|
239 |
+
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
240 |
+
),
|
241 |
+
"flaubert": (
|
242 |
+
FlaubertConfig,
|
243 |
+
TFFlaubertWithLMHeadModel,
|
244 |
+
FlaubertWithLMHeadModel,
|
245 |
+
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
246 |
+
),
|
247 |
+
"distilbert": (
|
248 |
+
DistilBertConfig,
|
249 |
+
TFDistilBertForMaskedLM,
|
250 |
+
DistilBertForMaskedLM,
|
251 |
+
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
252 |
+
),
|
253 |
+
"distilbert-base-distilled-squad": (
|
254 |
+
DistilBertConfig,
|
255 |
+
TFDistilBertForQuestionAnswering,
|
256 |
+
DistilBertForQuestionAnswering,
|
257 |
+
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
258 |
+
),
|
259 |
+
"lxmert": (
|
260 |
+
LxmertConfig,
|
261 |
+
TFLxmertForPreTraining,
|
262 |
+
LxmertForPreTraining,
|
263 |
+
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
264 |
+
),
|
265 |
+
"lxmert-visual-feature-encoder": (
|
266 |
+
LxmertConfig,
|
267 |
+
TFLxmertVisualFeatureEncoder,
|
268 |
+
LxmertVisualFeatureEncoder,
|
269 |
+
LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
270 |
+
),
|
271 |
+
"ctrl": (
|
272 |
+
CTRLConfig,
|
273 |
+
TFCTRLLMHeadModel,
|
274 |
+
CTRLLMHeadModel,
|
275 |
+
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
276 |
+
),
|
277 |
+
"albert": (
|
278 |
+
AlbertConfig,
|
279 |
+
TFAlbertForPreTraining,
|
280 |
+
AlbertForPreTraining,
|
281 |
+
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
282 |
+
),
|
283 |
+
"t5": (
|
284 |
+
T5Config,
|
285 |
+
TFT5ForConditionalGeneration,
|
286 |
+
T5ForConditionalGeneration,
|
287 |
+
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
288 |
+
),
|
289 |
+
"electra": (
|
290 |
+
ElectraConfig,
|
291 |
+
TFElectraForPreTraining,
|
292 |
+
ElectraForPreTraining,
|
293 |
+
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
294 |
+
),
|
295 |
+
"wav2vec2": (
|
296 |
+
Wav2Vec2Config,
|
297 |
+
TFWav2Vec2Model,
|
298 |
+
Wav2Vec2Model,
|
299 |
+
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
300 |
+
),
|
301 |
+
}
|
302 |
+
|
303 |
+
|
304 |
+
def convert_pt_checkpoint_to_tf(
|
305 |
+
model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
|
306 |
+
):
|
307 |
+
if model_type not in MODEL_CLASSES:
|
308 |
+
raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.")
|
309 |
+
|
310 |
+
config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]
|
311 |
+
|
312 |
+
# Initialise TF model
|
313 |
+
if config_file in aws_config_map:
|
314 |
+
config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
|
315 |
+
config = config_class.from_json_file(config_file)
|
316 |
+
config.output_hidden_states = True
|
317 |
+
config.output_attentions = True
|
318 |
+
print(f"Building TensorFlow model from configuration: {config}")
|
319 |
+
tf_model = model_class(config)
|
320 |
+
|
321 |
+
# Load weights from tf checkpoint
|
322 |
+
if pytorch_checkpoint_path in aws_config_map.keys():
|
323 |
+
pytorch_checkpoint_path = cached_file(
|
324 |
+
pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
|
325 |
+
)
|
326 |
+
# Load PyTorch checkpoint in tf2 model:
|
327 |
+
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
|
328 |
+
|
329 |
+
if compare_with_pt_model:
|
330 |
+
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
|
331 |
+
|
332 |
+
state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
|
333 |
+
pt_model = pt_model_class.from_pretrained(
|
334 |
+
pretrained_model_name_or_path=None, config=config, state_dict=state_dict
|
335 |
+
)
|
336 |
+
|
337 |
+
with torch.no_grad():
|
338 |
+
pto = pt_model(**pt_model.dummy_inputs)
|
339 |
+
|
340 |
+
np_pt = pto[0].numpy()
|
341 |
+
np_tf = tfo[0].numpy()
|
342 |
+
diff = np.amax(np.abs(np_pt - np_tf))
|
343 |
+
print(f"Max absolute difference between models outputs {diff}")
|
344 |
+
assert diff <= 2e-2, f"Error, model absolute difference is >2e-2: {diff}"
|
345 |
+
|
346 |
+
# Save pytorch-model
|
347 |
+
print(f"Save TensorFlow model to {tf_dump_path}")
|
348 |
+
tf_model.save_weights(tf_dump_path, save_format="h5")
|
349 |
+
|
350 |
+
|
351 |
+
def convert_all_pt_checkpoints_to_tf(
|
352 |
+
args_model_type,
|
353 |
+
tf_dump_path,
|
354 |
+
model_shortcut_names_or_path=None,
|
355 |
+
config_shortcut_names_or_path=None,
|
356 |
+
compare_with_pt_model=False,
|
357 |
+
use_cached_models=False,
|
358 |
+
remove_cached_files=False,
|
359 |
+
only_convert_finetuned_models=False,
|
360 |
+
):
|
361 |
+
if args_model_type is None:
|
362 |
+
model_types = list(MODEL_CLASSES.keys())
|
363 |
+
else:
|
364 |
+
model_types = [args_model_type]
|
365 |
+
|
366 |
+
for j, model_type in enumerate(model_types, start=1):
|
367 |
+
print("=" * 100)
|
368 |
+
print(f" Converting model type {j}/{len(model_types)}: {model_type}")
|
369 |
+
print("=" * 100)
|
370 |
+
if model_type not in MODEL_CLASSES:
|
371 |
+
raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.")
|
372 |
+
|
373 |
+
config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
|
374 |
+
|
375 |
+
if model_shortcut_names_or_path is None:
|
376 |
+
model_shortcut_names_or_path = list(aws_model_maps.keys())
|
377 |
+
if config_shortcut_names_or_path is None:
|
378 |
+
config_shortcut_names_or_path = model_shortcut_names_or_path
|
379 |
+
|
380 |
+
for i, (model_shortcut_name, config_shortcut_name) in enumerate(
|
381 |
+
zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
|
382 |
+
):
|
383 |
+
print("-" * 100)
|
384 |
+
if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
|
385 |
+
if not only_convert_finetuned_models:
|
386 |
+
print(f" Skipping finetuned checkpoint {model_shortcut_name}")
|
387 |
+
continue
|
388 |
+
model_type = model_shortcut_name
|
389 |
+
elif only_convert_finetuned_models:
|
390 |
+
print(f" Skipping not finetuned checkpoint {model_shortcut_name}")
|
391 |
+
continue
|
392 |
+
print(
|
393 |
+
f" Converting checkpoint {i}/{len(aws_config_map)}: {model_shortcut_name} - model_type {model_type}"
|
394 |
+
)
|
395 |
+
print("-" * 100)
|
396 |
+
|
397 |
+
if config_shortcut_name in aws_config_map:
|
398 |
+
config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
|
399 |
+
else:
|
400 |
+
config_file = config_shortcut_name
|
401 |
+
|
402 |
+
if model_shortcut_name in aws_model_maps:
|
403 |
+
model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
|
404 |
+
else:
|
405 |
+
model_file = model_shortcut_name
|
406 |
+
|
407 |
+
if os.path.isfile(model_shortcut_name):
|
408 |
+
model_shortcut_name = "converted_model"
|
409 |
+
|
410 |
+
convert_pt_checkpoint_to_tf(
|
411 |
+
model_type=model_type,
|
412 |
+
pytorch_checkpoint_path=model_file,
|
413 |
+
config_file=config_file,
|
414 |
+
tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
|
415 |
+
compare_with_pt_model=compare_with_pt_model,
|
416 |
+
)
|
417 |
+
if remove_cached_files:
|
418 |
+
os.remove(config_file)
|
419 |
+
os.remove(model_file)
|
420 |
+
|
421 |
+
|
422 |
+
if __name__ == "__main__":
|
423 |
+
parser = argparse.ArgumentParser()
|
424 |
+
# Required parameters
|
425 |
+
parser.add_argument(
|
426 |
+
"--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
|
427 |
+
)
|
428 |
+
parser.add_argument(
|
429 |
+
"--model_type",
|
430 |
+
default=None,
|
431 |
+
type=str,
|
432 |
+
help=(
|
433 |
+
f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and "
|
434 |
+
"convert all the models from AWS."
|
435 |
+
),
|
436 |
+
)
|
437 |
+
parser.add_argument(
|
438 |
+
"--pytorch_checkpoint_path",
|
439 |
+
default=None,
|
440 |
+
type=str,
|
441 |
+
help=(
|
442 |
+
"Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
|
443 |
+
"If not given, will download and convert all the checkpoints from AWS."
|
444 |
+
),
|
445 |
+
)
|
446 |
+
parser.add_argument(
|
447 |
+
"--config_file",
|
448 |
+
default=None,
|
449 |
+
type=str,
|
450 |
+
help=(
|
451 |
+
"The config json file corresponding to the pre-trained model. \n"
|
452 |
+
"This specifies the model architecture. If not given and "
|
453 |
+
"--pytorch_checkpoint_path is not given or is a shortcut name "
|
454 |
+
"use the configuration associated to the shortcut name on the AWS"
|
455 |
+
),
|
456 |
+
)
|
457 |
+
parser.add_argument(
|
458 |
+
"--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
|
459 |
+
)
|
460 |
+
parser.add_argument(
|
461 |
+
"--use_cached_models",
|
462 |
+
action="store_true",
|
463 |
+
help="Use cached models if possible instead of updating to latest checkpoint versions.",
|
464 |
+
)
|
465 |
+
parser.add_argument(
|
466 |
+
"--remove_cached_files",
|
467 |
+
action="store_true",
|
468 |
+
help="Remove pytorch models after conversion (save memory when converting in batches).",
|
469 |
+
)
|
470 |
+
parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
|
471 |
+
args = parser.parse_args()
|
472 |
+
|
473 |
+
# if args.pytorch_checkpoint_path is not None:
|
474 |
+
# convert_pt_checkpoint_to_tf(args.model_type.lower(),
|
475 |
+
# args.pytorch_checkpoint_path,
|
476 |
+
# args.config_file if args.config_file is not None else args.pytorch_checkpoint_path,
|
477 |
+
# args.tf_dump_path,
|
478 |
+
# compare_with_pt_model=args.compare_with_pt_model,
|
479 |
+
# use_cached_models=args.use_cached_models)
|
480 |
+
# else:
|
481 |
+
convert_all_pt_checkpoints_to_tf(
|
482 |
+
args.model_type.lower() if args.model_type is not None else None,
|
483 |
+
args.tf_dump_path,
|
484 |
+
model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
|
485 |
+
if args.pytorch_checkpoint_path is not None
|
486 |
+
else None,
|
487 |
+
config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
|
488 |
+
compare_with_pt_model=args.compare_with_pt_model,
|
489 |
+
use_cached_models=args.use_cached_models,
|
490 |
+
remove_cached_files=args.remove_cached_files,
|
491 |
+
only_convert_finetuned_models=args.only_convert_finetuned_models,
|
492 |
+
)
|
transformers_4_35_0/convert_slow_tokenizer.py
ADDED
@@ -0,0 +1,1318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Utilities to convert slow tokenizers in their fast tokenizers counterparts.
|
17 |
+
|
18 |
+
All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and
|
19 |
+
allow to make our dependency on SentencePiece optional.
|
20 |
+
"""
|
21 |
+
|
22 |
+
import warnings
|
23 |
+
from typing import Dict, List, Tuple
|
24 |
+
|
25 |
+
from packaging import version
|
26 |
+
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
|
27 |
+
from tokenizers.models import BPE, Unigram, WordPiece
|
28 |
+
|
29 |
+
from .utils import is_protobuf_available, requires_backends
|
30 |
+
from .utils.import_utils import PROTOBUF_IMPORT_ERROR
|
31 |
+
|
32 |
+
|
33 |
+
def import_protobuf(error_message=""):
|
34 |
+
if is_protobuf_available():
|
35 |
+
import google.protobuf
|
36 |
+
|
37 |
+
if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
|
38 |
+
from transformers.utils import sentencepiece_model_pb2
|
39 |
+
else:
|
40 |
+
from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
|
41 |
+
return sentencepiece_model_pb2
|
42 |
+
else:
|
43 |
+
raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
|
44 |
+
|
45 |
+
|
46 |
+
class SentencePieceExtractor:
|
47 |
+
"""
|
48 |
+
Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(self, model: str):
|
52 |
+
requires_backends(self, "sentencepiece")
|
53 |
+
from sentencepiece import SentencePieceProcessor
|
54 |
+
|
55 |
+
self.sp = SentencePieceProcessor()
|
56 |
+
self.sp.Load(model)
|
57 |
+
|
58 |
+
def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
|
59 |
+
"""
|
60 |
+
By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
|
61 |
+
order the merges with respect to the piece scores instead.
|
62 |
+
"""
|
63 |
+
sp = self.sp
|
64 |
+
vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
|
65 |
+
if vocab_scores is not None:
|
66 |
+
vocab_scores, reverse = dict(vocab_scores), True
|
67 |
+
else:
|
68 |
+
vocab_scores, reverse = vocab, False
|
69 |
+
|
70 |
+
# Merges
|
71 |
+
merges = []
|
72 |
+
for merge, piece_score in vocab_scores.items():
|
73 |
+
local = []
|
74 |
+
for index in range(1, len(merge)):
|
75 |
+
piece_l, piece_r = merge[:index], merge[index:]
|
76 |
+
if piece_l in vocab and piece_r in vocab:
|
77 |
+
local.append((piece_l, piece_r, piece_score))
|
78 |
+
local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
|
79 |
+
merges.extend(local)
|
80 |
+
|
81 |
+
merges = sorted(merges, key=lambda val: val[2], reverse=reverse)
|
82 |
+
merges = [(val[0], val[1]) for val in merges]
|
83 |
+
return vocab, merges
|
84 |
+
|
85 |
+
|
86 |
+
def check_number_comma(piece: str) -> bool:
|
87 |
+
return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
|
88 |
+
|
89 |
+
|
90 |
+
class Converter:
|
91 |
+
def __init__(self, original_tokenizer):
|
92 |
+
self.original_tokenizer = original_tokenizer
|
93 |
+
|
94 |
+
def converted(self) -> Tokenizer:
|
95 |
+
raise NotImplementedError()
|
96 |
+
|
97 |
+
|
98 |
+
class BertConverter(Converter):
|
99 |
+
def converted(self) -> Tokenizer:
|
100 |
+
vocab = self.original_tokenizer.vocab
|
101 |
+
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
|
102 |
+
|
103 |
+
tokenize_chinese_chars = False
|
104 |
+
strip_accents = False
|
105 |
+
do_lower_case = False
|
106 |
+
if hasattr(self.original_tokenizer, "basic_tokenizer"):
|
107 |
+
tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
|
108 |
+
strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
|
109 |
+
do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
|
110 |
+
|
111 |
+
tokenizer.normalizer = normalizers.BertNormalizer(
|
112 |
+
clean_text=True,
|
113 |
+
handle_chinese_chars=tokenize_chinese_chars,
|
114 |
+
strip_accents=strip_accents,
|
115 |
+
lowercase=do_lower_case,
|
116 |
+
)
|
117 |
+
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
|
118 |
+
|
119 |
+
cls = str(self.original_tokenizer.cls_token)
|
120 |
+
sep = str(self.original_tokenizer.sep_token)
|
121 |
+
cls_token_id = self.original_tokenizer.cls_token_id
|
122 |
+
sep_token_id = self.original_tokenizer.sep_token_id
|
123 |
+
|
124 |
+
tokenizer.post_processor = processors.TemplateProcessing(
|
125 |
+
single=f"{cls}:0 $A:0 {sep}:0",
|
126 |
+
pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
|
127 |
+
special_tokens=[
|
128 |
+
(cls, cls_token_id),
|
129 |
+
(sep, sep_token_id),
|
130 |
+
],
|
131 |
+
)
|
132 |
+
tokenizer.decoder = decoders.WordPiece(prefix="##")
|
133 |
+
|
134 |
+
return tokenizer
|
135 |
+
|
136 |
+
|
137 |
+
class SplinterConverter(Converter):
|
138 |
+
def converted(self) -> Tokenizer:
|
139 |
+
vocab = self.original_tokenizer.vocab
|
140 |
+
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
|
141 |
+
|
142 |
+
tokenize_chinese_chars = False
|
143 |
+
strip_accents = False
|
144 |
+
do_lower_case = False
|
145 |
+
if hasattr(self.original_tokenizer, "basic_tokenizer"):
|
146 |
+
tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
|
147 |
+
strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
|
148 |
+
do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
|
149 |
+
|
150 |
+
tokenizer.normalizer = normalizers.BertNormalizer(
|
151 |
+
clean_text=True,
|
152 |
+
handle_chinese_chars=tokenize_chinese_chars,
|
153 |
+
strip_accents=strip_accents,
|
154 |
+
lowercase=do_lower_case,
|
155 |
+
)
|
156 |
+
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
|
157 |
+
|
158 |
+
cls = str(self.original_tokenizer.cls_token)
|
159 |
+
sep = str(self.original_tokenizer.sep_token)
|
160 |
+
question = str(self.original_tokenizer.question_token)
|
161 |
+
dot = "."
|
162 |
+
cls_token_id = self.original_tokenizer.cls_token_id
|
163 |
+
sep_token_id = self.original_tokenizer.sep_token_id
|
164 |
+
question_token_id = self.original_tokenizer.question_token_id
|
165 |
+
dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".")
|
166 |
+
|
167 |
+
if self.original_tokenizer.padding_side == "right":
|
168 |
+
pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
|
169 |
+
else:
|
170 |
+
pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"
|
171 |
+
|
172 |
+
tokenizer.post_processor = processors.TemplateProcessing(
|
173 |
+
single=f"{cls}:0 $A:0 {sep}:0",
|
174 |
+
pair=pair,
|
175 |
+
special_tokens=[
|
176 |
+
(cls, cls_token_id),
|
177 |
+
(sep, sep_token_id),
|
178 |
+
(question, question_token_id),
|
179 |
+
(dot, dot_token_id),
|
180 |
+
],
|
181 |
+
)
|
182 |
+
tokenizer.decoder = decoders.WordPiece(prefix="##")
|
183 |
+
|
184 |
+
return tokenizer
|
185 |
+
|
186 |
+
|
187 |
+
class FunnelConverter(Converter):
|
188 |
+
def converted(self) -> Tokenizer:
|
189 |
+
vocab = self.original_tokenizer.vocab
|
190 |
+
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
|
191 |
+
|
192 |
+
tokenize_chinese_chars = False
|
193 |
+
strip_accents = False
|
194 |
+
do_lower_case = False
|
195 |
+
if hasattr(self.original_tokenizer, "basic_tokenizer"):
|
196 |
+
tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
|
197 |
+
strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
|
198 |
+
do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
|
199 |
+
|
200 |
+
tokenizer.normalizer = normalizers.BertNormalizer(
|
201 |
+
clean_text=True,
|
202 |
+
handle_chinese_chars=tokenize_chinese_chars,
|
203 |
+
strip_accents=strip_accents,
|
204 |
+
lowercase=do_lower_case,
|
205 |
+
)
|
206 |
+
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
|
207 |
+
|
208 |
+
cls = str(self.original_tokenizer.cls_token)
|
209 |
+
sep = str(self.original_tokenizer.sep_token)
|
210 |
+
cls_token_id = self.original_tokenizer.cls_token_id
|
211 |
+
sep_token_id = self.original_tokenizer.sep_token_id
|
212 |
+
|
213 |
+
tokenizer.post_processor = processors.TemplateProcessing(
|
214 |
+
single=f"{cls}:2 $A:0 {sep}:0", # token_type_id is 2 for Funnel transformer
|
215 |
+
pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
|
216 |
+
special_tokens=[
|
217 |
+
(cls, cls_token_id),
|
218 |
+
(sep, sep_token_id),
|
219 |
+
],
|
220 |
+
)
|
221 |
+
tokenizer.decoder = decoders.WordPiece(prefix="##")
|
222 |
+
|
223 |
+
return tokenizer
|
224 |
+
|
225 |
+
|
226 |
+
class MPNetConverter(Converter):
|
227 |
+
def converted(self) -> Tokenizer:
|
228 |
+
vocab = self.original_tokenizer.vocab
|
229 |
+
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
|
230 |
+
|
231 |
+
tokenize_chinese_chars = False
|
232 |
+
strip_accents = False
|
233 |
+
do_lower_case = False
|
234 |
+
if hasattr(self.original_tokenizer, "basic_tokenizer"):
|
235 |
+
tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
|
236 |
+
strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
|
237 |
+
do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
|
238 |
+
|
239 |
+
tokenizer.normalizer = normalizers.BertNormalizer(
|
240 |
+
clean_text=True,
|
241 |
+
handle_chinese_chars=tokenize_chinese_chars,
|
242 |
+
strip_accents=strip_accents,
|
243 |
+
lowercase=do_lower_case,
|
244 |
+
)
|
245 |
+
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
|
246 |
+
|
247 |
+
cls = str(self.original_tokenizer.cls_token)
|
248 |
+
sep = str(self.original_tokenizer.sep_token)
|
249 |
+
cls_token_id = self.original_tokenizer.cls_token_id
|
250 |
+
sep_token_id = self.original_tokenizer.sep_token_id
|
251 |
+
|
252 |
+
tokenizer.post_processor = processors.TemplateProcessing(
|
253 |
+
single=f"{cls}:0 $A:0 {sep}:0",
|
254 |
+
pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1", # MPNet uses two [SEP] tokens
|
255 |
+
special_tokens=[
|
256 |
+
(cls, cls_token_id),
|
257 |
+
(sep, sep_token_id),
|
258 |
+
],
|
259 |
+
)
|
260 |
+
tokenizer.decoder = decoders.WordPiece(prefix="##")
|
261 |
+
|
262 |
+
return tokenizer
|
263 |
+
|
264 |
+
|
265 |
+
class OpenAIGPTConverter(Converter):
|
266 |
+
def converted(self) -> Tokenizer:
|
267 |
+
vocab = self.original_tokenizer.encoder
|
268 |
+
merges = list(self.original_tokenizer.bpe_ranks.keys())
|
269 |
+
unk_token = self.original_tokenizer.unk_token
|
270 |
+
|
271 |
+
tokenizer = Tokenizer(
|
272 |
+
BPE(
|
273 |
+
vocab=vocab,
|
274 |
+
merges=merges,
|
275 |
+
dropout=None,
|
276 |
+
unk_token=str(unk_token),
|
277 |
+
end_of_word_suffix="</w>",
|
278 |
+
fuse_unk=False,
|
279 |
+
)
|
280 |
+
)
|
281 |
+
|
282 |
+
if tokenizer.token_to_id(str(unk_token)) is not None:
|
283 |
+
tokenizer.add_special_tokens([str(unk_token)])
|
284 |
+
|
285 |
+
tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
|
286 |
+
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
|
287 |
+
tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")
|
288 |
+
|
289 |
+
return tokenizer
|
290 |
+
|
291 |
+
|
292 |
+
class GPT2Converter(Converter):
|
293 |
+
def converted(self) -> Tokenizer:
|
294 |
+
vocab = self.original_tokenizer.encoder
|
295 |
+
merges = list(self.original_tokenizer.bpe_ranks.keys())
|
296 |
+
|
297 |
+
tokenizer = Tokenizer(
|
298 |
+
BPE(
|
299 |
+
vocab=vocab,
|
300 |
+
merges=merges,
|
301 |
+
dropout=None,
|
302 |
+
continuing_subword_prefix="",
|
303 |
+
end_of_word_suffix="",
|
304 |
+
fuse_unk=False,
|
305 |
+
)
|
306 |
+
)
|
307 |
+
|
308 |
+
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
|
309 |
+
tokenizer.decoder = decoders.ByteLevel()
|
310 |
+
if self.original_tokenizer.add_bos_token:
|
311 |
+
bos = self.original_tokenizer.bos_token
|
312 |
+
bos_token_id = self.original_tokenizer.bos_token_id
|
313 |
+
tokenizer.post_processor = processors.TemplateProcessing(
|
314 |
+
single=f"{bos}:0 $A:0",
|
315 |
+
pair=f"{bos}:0 $A:0 $B:1",
|
316 |
+
special_tokens=[
|
317 |
+
(bos, bos_token_id),
|
318 |
+
],
|
319 |
+
)
|
320 |
+
else:
|
321 |
+
# XXX trim_offsets=False actually means this post_processor doesn't
|
322 |
+
# really do anything.
|
323 |
+
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
|
324 |
+
return tokenizer
|
325 |
+
|
326 |
+
|
327 |
+
class HerbertConverter(Converter):
|
328 |
+
def converted(self) -> Tokenizer:
|
329 |
+
tokenizer_info_str = "#version:"
|
330 |
+
token_suffix = "</w>"
|
331 |
+
|
332 |
+
vocab = self.original_tokenizer.encoder
|
333 |
+
merges = list(self.original_tokenizer.bpe_ranks.keys())
|
334 |
+
if tokenizer_info_str in merges[0][0]:
|
335 |
+
merges = merges[1:]
|
336 |
+
|
337 |
+
tokenizer = Tokenizer(
|
338 |
+
BPE(
|
339 |
+
vocab,
|
340 |
+
merges,
|
341 |
+
dropout=None,
|
342 |
+
unk_token=self.original_tokenizer.unk_token,
|
343 |
+
end_of_word_suffix=token_suffix,
|
344 |
+
)
|
345 |
+
)
|
346 |
+
|
347 |
+
tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
|
348 |
+
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
|
349 |
+
tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
|
350 |
+
tokenizer.post_processor = processors.BertProcessing(
|
351 |
+
sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
|
352 |
+
cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
|
353 |
+
)
|
354 |
+
|
355 |
+
return tokenizer
|
356 |
+
|
357 |
+
|
358 |
+
class RobertaConverter(Converter):
|
359 |
+
def converted(self) -> Tokenizer:
|
360 |
+
ot = self.original_tokenizer
|
361 |
+
vocab = ot.encoder
|
362 |
+
merges = list(ot.bpe_ranks.keys())
|
363 |
+
|
364 |
+
tokenizer = Tokenizer(
|
365 |
+
BPE(
|
366 |
+
vocab=vocab,
|
367 |
+
merges=merges,
|
368 |
+
dropout=None,
|
369 |
+
continuing_subword_prefix="",
|
370 |
+
end_of_word_suffix="",
|
371 |
+
fuse_unk=False,
|
372 |
+
)
|
373 |
+
)
|
374 |
+
|
375 |
+
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
|
376 |
+
tokenizer.decoder = decoders.ByteLevel()
|
377 |
+
tokenizer.post_processor = processors.RobertaProcessing(
|
378 |
+
sep=(ot.sep_token, ot.sep_token_id),
|
379 |
+
cls=(ot.cls_token, ot.cls_token_id),
|
380 |
+
add_prefix_space=ot.add_prefix_space,
|
381 |
+
trim_offsets=True, # True by default on Roberta (historical)
|
382 |
+
)
|
383 |
+
|
384 |
+
return tokenizer
|
385 |
+
|
386 |
+
|
387 |
+
class RoFormerConverter(Converter):
|
388 |
+
def converted(self) -> Tokenizer:
|
389 |
+
from .models.roformer.tokenization_utils import JiebaPreTokenizer
|
390 |
+
|
391 |
+
vocab = self.original_tokenizer.vocab
|
392 |
+
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
|
393 |
+
|
394 |
+
strip_accents = False
|
395 |
+
do_lower_case = False
|
396 |
+
if hasattr(self.original_tokenizer, "basic_tokenizer"):
|
397 |
+
strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
|
398 |
+
do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
|
399 |
+
|
400 |
+
tokenizer.normalizer = normalizers.BertNormalizer(
|
401 |
+
clean_text=True,
|
402 |
+
handle_chinese_chars=False,
|
403 |
+
strip_accents=strip_accents,
|
404 |
+
lowercase=do_lower_case,
|
405 |
+
)
|
406 |
+
tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))
|
407 |
+
|
408 |
+
cls = str(self.original_tokenizer.cls_token)
|
409 |
+
sep = str(self.original_tokenizer.sep_token)
|
410 |
+
cls_token_id = self.original_tokenizer.cls_token_id
|
411 |
+
sep_token_id = self.original_tokenizer.sep_token_id
|
412 |
+
|
413 |
+
tokenizer.post_processor = processors.TemplateProcessing(
|
414 |
+
single=f"{cls}:0 $A:0 {sep}:0",
|
415 |
+
pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
|
416 |
+
special_tokens=[
|
417 |
+
(cls, cls_token_id),
|
418 |
+
(sep, sep_token_id),
|
419 |
+
],
|
420 |
+
)
|
421 |
+
tokenizer.decoder = decoders.WordPiece(prefix="##")
|
422 |
+
|
423 |
+
return tokenizer
|
424 |
+
|
425 |
+
|
426 |
+
class DebertaConverter(Converter):
|
427 |
+
def converted(self) -> Tokenizer:
|
428 |
+
ot = self.original_tokenizer
|
429 |
+
vocab = ot.encoder
|
430 |
+
merges = list(ot.bpe_ranks.keys())
|
431 |
+
|
432 |
+
tokenizer = Tokenizer(
|
433 |
+
BPE(
|
434 |
+
vocab=vocab,
|
435 |
+
merges=merges,
|
436 |
+
dropout=None,
|
437 |
+
continuing_subword_prefix="",
|
438 |
+
end_of_word_suffix="",
|
439 |
+
fuse_unk=False,
|
440 |
+
)
|
441 |
+
)
|
442 |
+
|
443 |
+
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
|
444 |
+
tokenizer.decoder = decoders.ByteLevel()
|
445 |
+
tokenizer.post_processor = processors.TemplateProcessing(
|
446 |
+
single="[CLS]:0 $A:0 [SEP]:0",
|
447 |
+
pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
|
448 |
+
special_tokens=[
|
449 |
+
("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
|
450 |
+
("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
|
451 |
+
],
|
452 |
+
)
|
453 |
+
|
454 |
+
return tokenizer
|
455 |
+
|
456 |
+
|
457 |
+
class SpmConverter(Converter):
|
458 |
+
def __init__(self, *args):
|
459 |
+
requires_backends(self, "protobuf")
|
460 |
+
|
461 |
+
super().__init__(*args)
|
462 |
+
|
463 |
+
# from .utils import sentencepiece_model_pb2 as model_pb2
|
464 |
+
model_pb2 = import_protobuf()
|
465 |
+
|
466 |
+
m = model_pb2.ModelProto()
|
467 |
+
with open(self.original_tokenizer.vocab_file, "rb") as f:
|
468 |
+
m.ParseFromString(f.read())
|
469 |
+
self.proto = m
|
470 |
+
|
471 |
+
if self.proto.trainer_spec.byte_fallback:
|
472 |
+
if not getattr(self, "handle_byte_fallback", None):
|
473 |
+
warnings.warn(
|
474 |
+
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
|
475 |
+
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
|
476 |
+
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
|
477 |
+
"unknown tokens into a sequence of byte tokens matching the original piece of text."
|
478 |
+
)
|
479 |
+
|
480 |
+
def vocab(self, proto):
|
481 |
+
return [(piece.piece, piece.score) for piece in proto.pieces]
|
482 |
+
|
483 |
+
def unk_id(self, proto):
|
484 |
+
return proto.trainer_spec.unk_id
|
485 |
+
|
486 |
+
def tokenizer(self, proto):
|
487 |
+
model_type = proto.trainer_spec.model_type
|
488 |
+
vocab_scores = self.vocab(proto)
|
489 |
+
unk_id = self.unk_id(proto)
|
490 |
+
|
491 |
+
if model_type == 1:
|
492 |
+
tokenizer = Tokenizer(Unigram(vocab_scores, unk_id))
|
493 |
+
elif model_type == 2:
|
494 |
+
_, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
|
495 |
+
bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
|
496 |
+
tokenizer = Tokenizer(
|
497 |
+
BPE(
|
498 |
+
bpe_vocab,
|
499 |
+
merges,
|
500 |
+
unk_token=proto.trainer_spec.unk_piece,
|
501 |
+
fuse_unk=True,
|
502 |
+
)
|
503 |
+
)
|
504 |
+
else:
|
505 |
+
raise Exception(
|
506 |
+
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
|
507 |
+
)
|
508 |
+
|
509 |
+
return tokenizer
|
510 |
+
|
511 |
+
def normalizer(self, proto):
|
512 |
+
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
|
513 |
+
if not precompiled_charsmap:
|
514 |
+
return normalizers.Sequence([normalizers.Replace(Regex(" {2,}"), " ")])
|
515 |
+
else:
|
516 |
+
return normalizers.Sequence(
|
517 |
+
[normalizers.Precompiled(precompiled_charsmap), normalizers.Replace(Regex(" {2,}"), " ")]
|
518 |
+
)
|
519 |
+
|
520 |
+
def pre_tokenizer(self, replacement, add_prefix_space):
|
521 |
+
return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
|
522 |
+
|
523 |
+
def post_processor(self):
|
524 |
+
return None
|
525 |
+
|
526 |
+
def decoder(self, replacement, add_prefix_space):
|
527 |
+
return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
|
528 |
+
|
529 |
+
def converted(self) -> Tokenizer:
|
530 |
+
tokenizer = self.tokenizer(self.proto)
|
531 |
+
|
532 |
+
# Tokenizer assemble
|
533 |
+
normalizer = self.normalizer(self.proto)
|
534 |
+
if normalizer is not None:
|
535 |
+
tokenizer.normalizer = normalizer
|
536 |
+
|
537 |
+
replacement = "▁"
|
538 |
+
add_prefix_space = True
|
539 |
+
pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
|
540 |
+
if pre_tokenizer is not None:
|
541 |
+
tokenizer.pre_tokenizer = pre_tokenizer
|
542 |
+
|
543 |
+
tokenizer.decoder = self.decoder(replacement, add_prefix_space)
|
544 |
+
post_processor = self.post_processor()
|
545 |
+
if post_processor:
|
546 |
+
tokenizer.post_processor = post_processor
|
547 |
+
|
548 |
+
return tokenizer
|
549 |
+
|
550 |
+
|
551 |
+
class AlbertConverter(SpmConverter):
|
552 |
+
def vocab(self, proto):
|
553 |
+
return [
|
554 |
+
(piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
|
555 |
+
for piece in proto.pieces
|
556 |
+
]
|
557 |
+
|
558 |
+
def normalizer(self, proto):
|
559 |
+
list_normalizers = [
|
560 |
+
normalizers.Replace("``", '"'),
|
561 |
+
normalizers.Replace("''", '"'),
|
562 |
+
]
|
563 |
+
if not self.original_tokenizer.keep_accents:
|
564 |
+
list_normalizers.append(normalizers.NFKD())
|
565 |
+
list_normalizers.append(normalizers.StripAccents())
|
566 |
+
if self.original_tokenizer.do_lower_case:
|
567 |
+
list_normalizers.append(normalizers.Lowercase())
|
568 |
+
|
569 |
+
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
|
570 |
+
|
571 |
+
if precompiled_charsmap:
|
572 |
+
list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
|
573 |
+
|
574 |
+
list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
|
575 |
+
return normalizers.Sequence(list_normalizers)
|
576 |
+
|
577 |
+
def post_processor(self):
|
578 |
+
return processors.TemplateProcessing(
|
579 |
+
single="[CLS]:0 $A:0 [SEP]:0",
|
580 |
+
pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
|
581 |
+
special_tokens=[
|
582 |
+
("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
|
583 |
+
("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
|
584 |
+
],
|
585 |
+
)
|
586 |
+
|
587 |
+
|
588 |
+
class BarthezConverter(SpmConverter):
|
589 |
+
def unk_id(self, proto):
|
590 |
+
unk_id = 3
|
591 |
+
return unk_id
|
592 |
+
|
593 |
+
def post_processor(self):
|
594 |
+
return processors.TemplateProcessing(
|
595 |
+
single="<s> $A </s>",
|
596 |
+
pair="<s> $A </s> </s> $B </s>",
|
597 |
+
special_tokens=[
|
598 |
+
("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
|
599 |
+
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
|
600 |
+
],
|
601 |
+
)
|
602 |
+
|
603 |
+
|
604 |
+
class CamembertConverter(SpmConverter):
|
605 |
+
def vocab(self, proto):
|
606 |
+
vocab = [
|
607 |
+
("<s>NOTUSED", 0.0),
|
608 |
+
("<pad>", 0.0),
|
609 |
+
("</s>NOTUSED", 0.0),
|
610 |
+
("<unk>", 0.0),
|
611 |
+
("<unk>NOTUSED", -100),
|
612 |
+
]
|
613 |
+
# We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead
|
614 |
+
vocab += [(piece.piece, piece.score) for piece in proto.pieces[1:]]
|
615 |
+
vocab += [("<mask>", 0.0)]
|
616 |
+
return vocab
|
617 |
+
|
618 |
+
def unk_id(self, proto):
|
619 |
+
# See vocab unk position
|
620 |
+
return 3
|
621 |
+
|
622 |
+
def post_processor(self):
|
623 |
+
return processors.TemplateProcessing(
|
624 |
+
single="<s> $A </s>",
|
625 |
+
pair="<s> $A </s> </s> $B </s>",
|
626 |
+
special_tokens=[
|
627 |
+
("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
|
628 |
+
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
|
629 |
+
],
|
630 |
+
)
|
631 |
+
|
632 |
+
|
633 |
+
class DebertaV2Converter(SpmConverter):
|
634 |
+
def pre_tokenizer(self, replacement, add_prefix_space):
|
635 |
+
list_pretokenizers = []
|
636 |
+
if self.original_tokenizer.split_by_punct:
|
637 |
+
list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
|
638 |
+
list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
|
639 |
+
return pre_tokenizers.Sequence(list_pretokenizers)
|
640 |
+
|
641 |
+
def normalizer(self, proto):
|
642 |
+
list_normalizers = []
|
643 |
+
if self.original_tokenizer.do_lower_case:
|
644 |
+
list_normalizers.append(normalizers.Lowercase())
|
645 |
+
list_normalizers.append(normalizers.Strip())
|
646 |
+
|
647 |
+
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
|
648 |
+
if precompiled_charsmap:
|
649 |
+
list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
|
650 |
+
list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
|
651 |
+
|
652 |
+
return normalizers.Sequence(list_normalizers)
|
653 |
+
|
654 |
+
def post_processor(self):
|
655 |
+
return processors.TemplateProcessing(
|
656 |
+
single="[CLS]:0 $A:0 [SEP]:0",
|
657 |
+
pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
|
658 |
+
special_tokens=[
|
659 |
+
("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
|
660 |
+
("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
|
661 |
+
],
|
662 |
+
)
|
663 |
+
|
664 |
+
|
665 |
+
class MBartConverter(SpmConverter):
|
666 |
+
def vocab(self, proto):
|
667 |
+
vocab = [
|
668 |
+
("<s>", 0.0),
|
669 |
+
("<pad>", 0.0),
|
670 |
+
("</s>", 0.0),
|
671 |
+
("<unk>", 0.0),
|
672 |
+
]
|
673 |
+
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
|
674 |
+
vocab += [
|
675 |
+
("ar_AR", 0.0),
|
676 |
+
("cs_CZ", 0.0),
|
677 |
+
("de_DE", 0.0),
|
678 |
+
("en_XX", 0.0),
|
679 |
+
("es_XX", 0.0),
|
680 |
+
("et_EE", 0.0),
|
681 |
+
("fi_FI", 0.0),
|
682 |
+
("fr_XX", 0.0),
|
683 |
+
("gu_IN", 0.0),
|
684 |
+
("hi_IN", 0.0),
|
685 |
+
("it_IT", 0.0),
|
686 |
+
("ja_XX", 0.0),
|
687 |
+
("kk_KZ", 0.0),
|
688 |
+
("ko_KR", 0.0),
|
689 |
+
("lt_LT", 0.0),
|
690 |
+
("lv_LV", 0.0),
|
691 |
+
("my_MM", 0.0),
|
692 |
+
("ne_NP", 0.0),
|
693 |
+
("nl_XX", 0.0),
|
694 |
+
("ro_RO", 0.0),
|
695 |
+
("ru_RU", 0.0),
|
696 |
+
("si_LK", 0.0),
|
697 |
+
("tr_TR", 0.0),
|
698 |
+
("vi_VN", 0.0),
|
699 |
+
("zh_CN", 0.0),
|
700 |
+
]
|
701 |
+
vocab += [("<mask>", 0.0)]
|
702 |
+
return vocab
|
703 |
+
|
704 |
+
def unk_id(self, proto):
|
705 |
+
return 3
|
706 |
+
|
707 |
+
def post_processor(self):
|
708 |
+
return processors.TemplateProcessing(
|
709 |
+
single="$A </s> en_XX",
|
710 |
+
pair="$A $B </s> en_XX",
|
711 |
+
special_tokens=[
|
712 |
+
("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
|
713 |
+
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
|
714 |
+
],
|
715 |
+
)
|
716 |
+
|
717 |
+
|
718 |
+
class MBart50Converter(SpmConverter):
|
719 |
+
def vocab(self, proto):
|
720 |
+
vocab = [
|
721 |
+
("<s>", 0.0),
|
722 |
+
("<pad>", 0.0),
|
723 |
+
("</s>", 0.0),
|
724 |
+
("<unk>", 0.0),
|
725 |
+
]
|
726 |
+
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
|
727 |
+
# fmt: off
|
728 |
+
vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)]
|
729 |
+
# fmt: on
|
730 |
+
vocab += [("<mask>", 0.0)]
|
731 |
+
return vocab
|
732 |
+
|
733 |
+
def unk_id(self, proto):
|
734 |
+
return 3
|
735 |
+
|
736 |
+
def post_processor(self):
|
737 |
+
return processors.TemplateProcessing(
|
738 |
+
single="en_XX $A </s>",
|
739 |
+
pair="en_XX $A $B </s>",
|
740 |
+
special_tokens=[
|
741 |
+
("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
|
742 |
+
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
|
743 |
+
],
|
744 |
+
)
|
745 |
+
|
746 |
+
|
747 |
+
class NllbConverter(SpmConverter):
|
748 |
+
def vocab(self, proto):
|
749 |
+
vocab = [
|
750 |
+
("<s>", 0.0),
|
751 |
+
("<pad>", 0.0),
|
752 |
+
("</s>", 0.0),
|
753 |
+
("<unk>", 0.0),
|
754 |
+
]
|
755 |
+
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
|
756 |
+
vocab += [
|
757 |
+
# fmt: off
|
758 |
+
('ace_Arab', 0.0), ('ace_Latn', 0.0), ('acm_Arab', 0.0), ('acq_Arab', 0.0), ('aeb_Arab', 0.0), ('afr_Latn', 0.0), ('ajp_Arab', 0.0), ('aka_Latn', 0.0), ('amh_Ethi', 0.0), ('apc_Arab', 0.0), ('arb_Arab', 0.0), ('ars_Arab', 0.0), ('ary_Arab', 0.0), ('arz_Arab', 0.0), ('asm_Beng', 0.0), ('ast_Latn', 0.0), ('awa_Deva', 0.0), ('ayr_Latn', 0.0), ('azb_Arab', 0.0), ('azj_Latn', 0.0), ('bak_Cyrl', 0.0), ('bam_Latn', 0.0), ('ban_Latn', 0.0), ('bel_Cyrl', 0.0), ('bem_Latn', 0.0), ('ben_Beng', 0.0), ('bho_Deva', 0.0), ('bjn_Arab', 0.0), ('bjn_Latn', 0.0), ('bod_Tibt', 0.0), ('bos_Latn', 0.0), ('bug_Latn', 0.0), ('bul_Cyrl', 0.0), ('cat_Latn', 0.0), ('ceb_Latn', 0.0), ('ces_Latn', 0.0), ('cjk_Latn', 0.0), ('ckb_Arab', 0.0), ('crh_Latn', 0.0), ('cym_Latn', 0.0), ('dan_Latn', 0.0), ('deu_Latn', 0.0), ('dik_Latn', 0.0), ('dyu_Latn', 0.0), ('dzo_Tibt', 0.0), ('ell_Grek', 0.0), ('eng_Latn', 0.0), ('epo_Latn', 0.0), ('est_Latn', 0.0), ('eus_Latn', 0.0), ('ewe_Latn', 0.0), ('fao_Latn', 0.0), ('pes_Arab', 0.0), ('fij_Latn', 0.0), ('fin_Latn', 0.0), ('fon_Latn', 0.0), ('fra_Latn', 0.0), ('fur_Latn', 0.0), ('fuv_Latn', 0.0), ('gla_Latn', 0.0), ('gle_Latn', 0.0), ('glg_Latn', 0.0), ('grn_Latn', 0.0), ('guj_Gujr', 0.0), ('hat_Latn', 0.0), ('hau_Latn', 0.0), ('heb_Hebr', 0.0), ('hin_Deva', 0.0), ('hne_Deva', 0.0), ('hrv_Latn', 0.0), ('hun_Latn', 0.0), ('hye_Armn', 0.0), ('ibo_Latn', 0.0), ('ilo_Latn', 0.0), ('ind_Latn', 0.0), ('isl_Latn', 0.0), ('ita_Latn', 0.0), ('jav_Latn', 0.0), ('jpn_Jpan', 0.0), ('kab_Latn', 0.0), ('kac_Latn', 0.0), ('kam_Latn', 0.0), ('kan_Knda', 0.0), ('kas_Arab', 0.0), ('kas_Deva', 0.0), ('kat_Geor', 0.0), ('knc_Arab', 0.0), ('knc_Latn', 0.0), ('kaz_Cyrl', 0.0), ('kbp_Latn', 0.0), ('kea_Latn', 0.0), ('khm_Khmr', 0.0), ('kik_Latn', 0.0), ('kin_Latn', 0.0), ('kir_Cyrl', 0.0), ('kmb_Latn', 0.0), ('kon_Latn', 0.0), ('kor_Hang', 0.0), ('kmr_Latn', 0.0), ('lao_Laoo', 0.0), ('lvs_Latn', 0.0), ('lij_Latn', 0.0), ('lim_Latn', 0.0), ('lin_Latn', 0.0), ('lit_Latn', 0.0), ('lmo_Latn', 0.0), ('ltg_Latn', 0.0), ('ltz_Latn', 0.0), ('lua_Latn', 0.0), ('lug_Latn', 0.0), ('luo_Latn', 0.0), ('lus_Latn', 0.0), ('mag_Deva', 0.0), ('mai_Deva', 0.0), ('mal_Mlym', 0.0), ('mar_Deva', 0.0), ('min_Latn', 0.0), ('mkd_Cyrl', 0.0), ('plt_Latn', 0.0), ('mlt_Latn', 0.0), ('mni_Beng', 0.0), ('khk_Cyrl', 0.0), ('mos_Latn', 0.0), ('mri_Latn', 0.0), ('zsm_Latn', 0.0), ('mya_Mymr', 0.0), ('nld_Latn', 0.0), ('nno_Latn', 0.0), ('nob_Latn', 0.0), ('npi_Deva', 0.0), ('nso_Latn', 0.0), ('nus_Latn', 0.0), ('nya_Latn', 0.0), ('oci_Latn', 0.0), ('gaz_Latn', 0.0), ('ory_Orya', 0.0), ('pag_Latn', 0.0), ('pan_Guru', 0.0), ('pap_Latn', 0.0), ('pol_Latn', 0.0), ('por_Latn', 0.0), ('prs_Arab', 0.0), ('pbt_Arab', 0.0), ('quy_Latn', 0.0), ('ron_Latn', 0.0), ('run_Latn', 0.0), ('rus_Cyrl', 0.0), ('sag_Latn', 0.0), ('san_Deva', 0.0), ('sat_Beng', 0.0), ('scn_Latn', 0.0), ('shn_Mymr', 0.0), ('sin_Sinh', 0.0), ('slk_Latn', 0.0), ('slv_Latn', 0.0), ('smo_Latn', 0.0), ('sna_Latn', 0.0), ('snd_Arab', 0.0), ('som_Latn', 0.0), ('sot_Latn', 0.0), ('spa_Latn', 0.0), ('als_Latn', 0.0), ('srd_Latn', 0.0), ('srp_Cyrl', 0.0), ('ssw_Latn', 0.0), ('sun_Latn', 0.0), ('swe_Latn', 0.0), ('swh_Latn', 0.0), ('szl_Latn', 0.0), ('tam_Taml', 0.0), ('tat_Cyrl', 0.0), ('tel_Telu', 0.0), ('tgk_Cyrl', 0.0), ('tgl_Latn', 0.0), ('tha_Thai', 0.0), ('tir_Ethi', 0.0), ('taq_Latn', 0.0), ('taq_Tfng', 0.0), ('tpi_Latn', 0.0), ('tsn_Latn', 0.0), ('tso_Latn', 0.0), ('tuk_Latn', 0.0), ('tum_Latn', 0.0), ('tur_Latn', 0.0), ('twi_Latn', 0.0), ('tzm_Tfng', 0.0), ('uig_Arab', 0.0), ('ukr_Cyrl', 0.0), ('umb_Latn', 0.0), ('urd_Arab', 0.0), ('uzn_Latn', 0.0), ('vec_Latn', 0.0), ('vie_Latn', 0.0), ('war_Latn', 0.0), ('wol_Latn', 0.0), ('xho_Latn', 0.0), ('ydd_Hebr', 0.0), ('yor_Latn', 0.0), ('yue_Hant', 0.0), ('zho_Hans', 0.0), ('zho_Hant', 0.0), ('zul_Latn', 0.0)
|
759 |
+
# fmt: on
|
760 |
+
]
|
761 |
+
vocab += [("<mask>", 0.0)]
|
762 |
+
return vocab
|
763 |
+
|
764 |
+
def unk_id(self, proto):
|
765 |
+
return 3
|
766 |
+
|
767 |
+
def post_processor(self):
|
768 |
+
return processors.TemplateProcessing(
|
769 |
+
single="eng_Latn $A </s>",
|
770 |
+
pair="eng_Latn $A $B </s>",
|
771 |
+
special_tokens=[
|
772 |
+
("eng_Latn", self.original_tokenizer.convert_tokens_to_ids("eng_Latn")),
|
773 |
+
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
|
774 |
+
],
|
775 |
+
)
|
776 |
+
|
777 |
+
|
778 |
+
class XLMRobertaConverter(SpmConverter):
|
779 |
+
def vocab(self, proto):
|
780 |
+
vocab = [
|
781 |
+
("<s>", 0.0),
|
782 |
+
("<pad>", 0.0),
|
783 |
+
("</s>", 0.0),
|
784 |
+
("<unk>", 0.0),
|
785 |
+
]
|
786 |
+
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
|
787 |
+
vocab += [("<mask>", 0.0)]
|
788 |
+
return vocab
|
789 |
+
|
790 |
+
def unk_id(self, proto):
|
791 |
+
unk_id = 3
|
792 |
+
return unk_id
|
793 |
+
|
794 |
+
def post_processor(self):
|
795 |
+
return processors.TemplateProcessing(
|
796 |
+
single="<s> $A </s>",
|
797 |
+
pair="<s> $A </s> </s> $B </s>",
|
798 |
+
special_tokens=[
|
799 |
+
("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
|
800 |
+
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
|
801 |
+
],
|
802 |
+
)
|
803 |
+
|
804 |
+
|
805 |
+
class XLNetConverter(SpmConverter):
|
806 |
+
def vocab(self, proto):
|
807 |
+
return [
|
808 |
+
(piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
|
809 |
+
for piece in proto.pieces
|
810 |
+
]
|
811 |
+
|
812 |
+
def normalizer(self, proto):
|
813 |
+
list_normalizers = [
|
814 |
+
normalizers.Replace("``", '"'),
|
815 |
+
normalizers.Replace("''", '"'),
|
816 |
+
]
|
817 |
+
if not self.original_tokenizer.keep_accents:
|
818 |
+
list_normalizers.append(normalizers.NFKD())
|
819 |
+
list_normalizers.append(normalizers.StripAccents())
|
820 |
+
if self.original_tokenizer.do_lower_case:
|
821 |
+
list_normalizers.append(normalizers.Lowercase())
|
822 |
+
|
823 |
+
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
|
824 |
+
|
825 |
+
if precompiled_charsmap:
|
826 |
+
list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
|
827 |
+
|
828 |
+
list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
|
829 |
+
return normalizers.Sequence(list_normalizers)
|
830 |
+
|
831 |
+
def post_processor(self):
|
832 |
+
return processors.TemplateProcessing(
|
833 |
+
single="$A:0 <sep>:0 <cls>:2",
|
834 |
+
pair="$A:0 <sep>:0 $B:1 <sep>:1 <cls>:2",
|
835 |
+
special_tokens=[
|
836 |
+
("<sep>", self.original_tokenizer.convert_tokens_to_ids("<sep>")),
|
837 |
+
("<cls>", self.original_tokenizer.convert_tokens_to_ids("<cls>")),
|
838 |
+
],
|
839 |
+
)
|
840 |
+
|
841 |
+
|
842 |
+
class ReformerConverter(SpmConverter):
|
843 |
+
pass
|
844 |
+
|
845 |
+
|
846 |
+
class RemBertConverter(SpmConverter):
|
847 |
+
# Inspired from AlbertConverter
|
848 |
+
def normalizer(self, proto):
|
849 |
+
list_normalizers = [
|
850 |
+
normalizers.Replace("``", '"'),
|
851 |
+
normalizers.Replace("''", '"'),
|
852 |
+
normalizers.Replace(Regex(" {2,}"), " "),
|
853 |
+
]
|
854 |
+
if not self.original_tokenizer.keep_accents:
|
855 |
+
list_normalizers.append(normalizers.NFKD())
|
856 |
+
list_normalizers.append(normalizers.StripAccents())
|
857 |
+
if self.original_tokenizer.do_lower_case:
|
858 |
+
list_normalizers.append(normalizers.Lowercase())
|
859 |
+
|
860 |
+
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
|
861 |
+
|
862 |
+
if precompiled_charsmap:
|
863 |
+
list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
|
864 |
+
|
865 |
+
return normalizers.Sequence(list_normalizers)
|
866 |
+
|
867 |
+
def post_processor(self):
|
868 |
+
return processors.TemplateProcessing(
|
869 |
+
single="[CLS]:0 $A:0 [SEP]:0",
|
870 |
+
pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
|
871 |
+
special_tokens=[
|
872 |
+
("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
|
873 |
+
("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
|
874 |
+
],
|
875 |
+
)
|
876 |
+
|
877 |
+
|
878 |
+
class BertGenerationConverter(SpmConverter):
|
879 |
+
pass
|
880 |
+
|
881 |
+
|
882 |
+
class PegasusConverter(SpmConverter):
|
883 |
+
def vocab(self, proto):
|
884 |
+
vocab = [
|
885 |
+
(self.original_tokenizer.pad_token, 0.0),
|
886 |
+
(self.original_tokenizer.eos_token, 0.0),
|
887 |
+
]
|
888 |
+
|
889 |
+
if self.original_tokenizer.mask_token_sent is not None:
|
890 |
+
vocab += [(self.original_tokenizer.mask_token_sent, 0.0)]
|
891 |
+
|
892 |
+
if (
|
893 |
+
self.original_tokenizer.mask_token is not None
|
894 |
+
and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset
|
895 |
+
):
|
896 |
+
vocab += [(self.original_tokenizer.mask_token, 0.0)]
|
897 |
+
|
898 |
+
vocab += [(f"<unk_{i}>", -100.0) for i in range(2, self.original_tokenizer.offset)]
|
899 |
+
vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
|
900 |
+
return vocab
|
901 |
+
|
902 |
+
def unk_id(self, proto):
|
903 |
+
return proto.trainer_spec.unk_id + self.original_tokenizer.offset
|
904 |
+
|
905 |
+
def pre_tokenizer(self, replacement, add_prefix_space):
|
906 |
+
return pre_tokenizers.Sequence(
|
907 |
+
[
|
908 |
+
pre_tokenizers.WhitespaceSplit(),
|
909 |
+
pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
|
910 |
+
]
|
911 |
+
)
|
912 |
+
|
913 |
+
def post_processor(self):
|
914 |
+
eos = self.original_tokenizer.eos_token
|
915 |
+
special_tokens = [
|
916 |
+
(eos, self.original_tokenizer.eos_token_id),
|
917 |
+
]
|
918 |
+
return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=special_tokens)
|
919 |
+
|
920 |
+
|
921 |
+
class T5Converter(SpmConverter):
|
922 |
+
def vocab(self, proto):
|
923 |
+
num_extra_ids = self.original_tokenizer._extra_ids
|
924 |
+
vocab = [(piece.piece, piece.score) for piece in proto.pieces]
|
925 |
+
vocab += [(f"<extra_id_{i}>", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
|
926 |
+
return vocab
|
927 |
+
|
928 |
+
def post_processor(self):
|
929 |
+
return processors.TemplateProcessing(
|
930 |
+
single=["$A", "</s>"],
|
931 |
+
pair=["$A", "</s>", "$B", "</s>"],
|
932 |
+
special_tokens=[
|
933 |
+
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
|
934 |
+
],
|
935 |
+
)
|
936 |
+
|
937 |
+
|
938 |
+
class WhisperConverter(Converter):
|
939 |
+
def converted(self) -> Tokenizer:
|
940 |
+
vocab = self.original_tokenizer.encoder
|
941 |
+
merges = list(self.original_tokenizer.bpe_ranks.keys())
|
942 |
+
|
943 |
+
tokenizer = Tokenizer(
|
944 |
+
BPE(
|
945 |
+
vocab=vocab,
|
946 |
+
merges=merges,
|
947 |
+
dropout=None,
|
948 |
+
continuing_subword_prefix="",
|
949 |
+
end_of_word_suffix="",
|
950 |
+
fuse_unk=False,
|
951 |
+
)
|
952 |
+
)
|
953 |
+
|
954 |
+
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
|
955 |
+
tokenizer.decoder = decoders.ByteLevel()
|
956 |
+
|
957 |
+
prefix_token_ids = self.original_tokenizer.prefix_tokens
|
958 |
+
prefixes = self.original_tokenizer.convert_ids_to_tokens(prefix_token_ids)
|
959 |
+
eos = self.original_tokenizer.eos_token
|
960 |
+
eos_token_id = self.original_tokenizer.eos_token_id
|
961 |
+
prefix_template = " ".join([f"{token}:0" for token in prefixes])
|
962 |
+
tokenizer.post_processor = processors.TemplateProcessing(
|
963 |
+
single=f"{prefix_template} $A:0 {eos}:0",
|
964 |
+
pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
|
965 |
+
special_tokens=[
|
966 |
+
(eos, eos_token_id),
|
967 |
+
*zip(prefixes, prefix_token_ids),
|
968 |
+
],
|
969 |
+
)
|
970 |
+
|
971 |
+
return tokenizer
|
972 |
+
|
973 |
+
|
974 |
+
class BigBirdConverter(SpmConverter):
|
975 |
+
def post_processor(self):
|
976 |
+
return processors.TemplateProcessing(
|
977 |
+
single="[CLS]:0 $A:0 [SEP]:0",
|
978 |
+
pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
|
979 |
+
special_tokens=[
|
980 |
+
("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
|
981 |
+
("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
|
982 |
+
],
|
983 |
+
)
|
984 |
+
|
985 |
+
|
986 |
+
class CLIPConverter(Converter):
|
987 |
+
def converted(self) -> Tokenizer:
|
988 |
+
vocab = self.original_tokenizer.encoder
|
989 |
+
merges = list(self.original_tokenizer.bpe_ranks.keys())
|
990 |
+
unk_token = self.original_tokenizer.unk_token
|
991 |
+
|
992 |
+
tokenizer = Tokenizer(
|
993 |
+
BPE(
|
994 |
+
vocab=vocab,
|
995 |
+
merges=merges,
|
996 |
+
dropout=None,
|
997 |
+
continuing_subword_prefix="",
|
998 |
+
end_of_word_suffix="</w>",
|
999 |
+
fuse_unk=False,
|
1000 |
+
unk_token=str(unk_token),
|
1001 |
+
)
|
1002 |
+
)
|
1003 |
+
|
1004 |
+
tokenizer.normalizer = normalizers.Sequence(
|
1005 |
+
[normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
|
1006 |
+
)
|
1007 |
+
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
|
1008 |
+
[
|
1009 |
+
pre_tokenizers.Split(
|
1010 |
+
Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
|
1011 |
+
behavior="removed",
|
1012 |
+
invert=True,
|
1013 |
+
),
|
1014 |
+
pre_tokenizers.ByteLevel(add_prefix_space=False),
|
1015 |
+
]
|
1016 |
+
)
|
1017 |
+
tokenizer.decoder = decoders.ByteLevel()
|
1018 |
+
|
1019 |
+
# Hack to have a ByteLevel and TemplaceProcessor
|
1020 |
+
tokenizer.post_processor = processors.RobertaProcessing(
|
1021 |
+
sep=(self.original_tokenizer.eos_token, self.original_tokenizer.eos_token_id),
|
1022 |
+
cls=(self.original_tokenizer.bos_token, self.original_tokenizer.bos_token_id),
|
1023 |
+
add_prefix_space=False,
|
1024 |
+
trim_offsets=False,
|
1025 |
+
)
|
1026 |
+
return tokenizer
|
1027 |
+
|
1028 |
+
|
1029 |
+
class LayoutLMv2Converter(Converter):
|
1030 |
+
def converted(self) -> Tokenizer:
|
1031 |
+
vocab = self.original_tokenizer.vocab
|
1032 |
+
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
|
1033 |
+
|
1034 |
+
tokenize_chinese_chars = False
|
1035 |
+
strip_accents = False
|
1036 |
+
do_lower_case = True
|
1037 |
+
if hasattr(self.original_tokenizer, "basic_tokenizer"):
|
1038 |
+
tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
|
1039 |
+
strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
|
1040 |
+
do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
|
1041 |
+
|
1042 |
+
tokenizer.normalizer = normalizers.BertNormalizer(
|
1043 |
+
clean_text=True,
|
1044 |
+
handle_chinese_chars=tokenize_chinese_chars,
|
1045 |
+
strip_accents=strip_accents,
|
1046 |
+
lowercase=do_lower_case,
|
1047 |
+
)
|
1048 |
+
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
|
1049 |
+
|
1050 |
+
cls = str(self.original_tokenizer.cls_token)
|
1051 |
+
sep = str(self.original_tokenizer.sep_token)
|
1052 |
+
cls_token_id = self.original_tokenizer.cls_token_id
|
1053 |
+
sep_token_id = self.original_tokenizer.sep_token_id
|
1054 |
+
|
1055 |
+
tokenizer.post_processor = processors.TemplateProcessing(
|
1056 |
+
single=f"{cls}:0 $A:0 {sep}:0",
|
1057 |
+
pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
|
1058 |
+
special_tokens=[
|
1059 |
+
(cls, cls_token_id),
|
1060 |
+
(sep, sep_token_id),
|
1061 |
+
],
|
1062 |
+
)
|
1063 |
+
tokenizer.decoder = decoders.WordPiece(prefix="##")
|
1064 |
+
|
1065 |
+
return tokenizer
|
1066 |
+
|
1067 |
+
|
1068 |
+
class BlenderbotConverter(Converter):
|
1069 |
+
def converted(self) -> Tokenizer:
|
1070 |
+
ot = self.original_tokenizer
|
1071 |
+
vocab = ot.encoder
|
1072 |
+
merges = list(ot.bpe_ranks.keys())
|
1073 |
+
|
1074 |
+
tokenizer = Tokenizer(
|
1075 |
+
BPE(
|
1076 |
+
vocab=vocab,
|
1077 |
+
merges=merges,
|
1078 |
+
dropout=None,
|
1079 |
+
continuing_subword_prefix="",
|
1080 |
+
end_of_word_suffix="",
|
1081 |
+
fuse_unk=False,
|
1082 |
+
)
|
1083 |
+
)
|
1084 |
+
|
1085 |
+
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
|
1086 |
+
tokenizer.decoder = decoders.ByteLevel()
|
1087 |
+
tokenizer.post_processor = processors.TemplateProcessing(
|
1088 |
+
single=f"$A:0 {ot.eos_token}:0",
|
1089 |
+
special_tokens=[
|
1090 |
+
(ot.eos_token, ot.eos_token_id),
|
1091 |
+
],
|
1092 |
+
)
|
1093 |
+
|
1094 |
+
return tokenizer
|
1095 |
+
|
1096 |
+
|
1097 |
+
class XGLMConverter(SpmConverter):
|
1098 |
+
def vocab(self, proto):
|
1099 |
+
vocab = [
|
1100 |
+
("<s>", 0.0),
|
1101 |
+
("<pad>", 0.0),
|
1102 |
+
("</s>", 0.0),
|
1103 |
+
("<unk>", 0.0),
|
1104 |
+
]
|
1105 |
+
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
|
1106 |
+
# fmt: off
|
1107 |
+
vocab += [("<madeupword0>", 0.0), ("<madeupword1>", 0.0), ("<madeupword2>", 0.0), ("<madeupword3>", 0.0), ("<madeupword4>", 0.0), ("<madeupword5>", 0.0), ("<madeupword6>", 0.0)]
|
1108 |
+
# fmt: on
|
1109 |
+
return vocab
|
1110 |
+
|
1111 |
+
def unk_id(self, proto):
|
1112 |
+
unk_id = 3
|
1113 |
+
return unk_id
|
1114 |
+
|
1115 |
+
def post_processor(self):
|
1116 |
+
return processors.TemplateProcessing(
|
1117 |
+
single="</s> $A",
|
1118 |
+
pair="</s> $A </s> </s> $B",
|
1119 |
+
special_tokens=[
|
1120 |
+
("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
|
1121 |
+
("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
|
1122 |
+
],
|
1123 |
+
)
|
1124 |
+
|
1125 |
+
|
1126 |
+
class LlamaConverter(SpmConverter):
|
1127 |
+
handle_byte_fallback = True
|
1128 |
+
|
1129 |
+
def vocab(self, proto):
|
1130 |
+
vocab = [
|
1131 |
+
("<unk>", 0.0),
|
1132 |
+
("<s>", 0.0),
|
1133 |
+
("</s>", 0.0),
|
1134 |
+
]
|
1135 |
+
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
|
1136 |
+
return vocab
|
1137 |
+
|
1138 |
+
def unk_id(self, proto):
|
1139 |
+
unk_id = 0
|
1140 |
+
return unk_id
|
1141 |
+
|
1142 |
+
def decoder(self, replacement, add_prefix_space):
|
1143 |
+
return decoders.Sequence(
|
1144 |
+
[
|
1145 |
+
decoders.Replace("▁", " "),
|
1146 |
+
decoders.ByteFallback(),
|
1147 |
+
decoders.Fuse(),
|
1148 |
+
decoders.Strip(content=" ", left=1),
|
1149 |
+
]
|
1150 |
+
)
|
1151 |
+
|
1152 |
+
def tokenizer(self, proto):
|
1153 |
+
model_type = proto.trainer_spec.model_type
|
1154 |
+
vocab_scores = self.vocab(proto)
|
1155 |
+
if model_type == 1:
|
1156 |
+
import tokenizers
|
1157 |
+
|
1158 |
+
if version.parse(tokenizers.__version__) < version.parse("0.14.0"):
|
1159 |
+
tokenizer = Tokenizer(Unigram(vocab_scores, 0))
|
1160 |
+
else:
|
1161 |
+
tokenizer = Tokenizer(Unigram(vocab_scores, 0, byte_fallback=True))
|
1162 |
+
|
1163 |
+
elif model_type == 2:
|
1164 |
+
_, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
|
1165 |
+
bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
|
1166 |
+
tokenizer = Tokenizer(
|
1167 |
+
BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
|
1168 |
+
)
|
1169 |
+
tokenizer.add_special_tokens(
|
1170 |
+
[
|
1171 |
+
AddedToken("<unk>"),
|
1172 |
+
AddedToken("<s>"),
|
1173 |
+
AddedToken("</s>"),
|
1174 |
+
]
|
1175 |
+
)
|
1176 |
+
else:
|
1177 |
+
raise Exception(
|
1178 |
+
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
|
1179 |
+
)
|
1180 |
+
|
1181 |
+
return tokenizer
|
1182 |
+
|
1183 |
+
def normalizer(self, proto):
|
1184 |
+
return normalizers.Sequence(
|
1185 |
+
[
|
1186 |
+
normalizers.Prepend(prepend="▁"),
|
1187 |
+
normalizers.Replace(pattern=" ", content="▁"),
|
1188 |
+
]
|
1189 |
+
)
|
1190 |
+
|
1191 |
+
def pre_tokenizer(self, replacement, add_prefix_space):
|
1192 |
+
return None
|
1193 |
+
|
1194 |
+
def post_processor(self):
|
1195 |
+
# the processor is defined in the LlamaTokenizerFast class.
|
1196 |
+
return None
|
1197 |
+
|
1198 |
+
|
1199 |
+
class MarkupLMConverter(Converter):
|
1200 |
+
def converted(self) -> Tokenizer:
|
1201 |
+
ot = self.original_tokenizer
|
1202 |
+
vocab = ot.encoder
|
1203 |
+
merges = list(ot.bpe_ranks.keys())
|
1204 |
+
|
1205 |
+
tokenizer = Tokenizer(
|
1206 |
+
BPE(
|
1207 |
+
vocab=vocab,
|
1208 |
+
merges=merges,
|
1209 |
+
dropout=None,
|
1210 |
+
continuing_subword_prefix="",
|
1211 |
+
end_of_word_suffix="",
|
1212 |
+
fuse_unk=False,
|
1213 |
+
unk_token=self.original_tokenizer.unk_token,
|
1214 |
+
)
|
1215 |
+
)
|
1216 |
+
|
1217 |
+
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
|
1218 |
+
tokenizer.decoder = decoders.ByteLevel()
|
1219 |
+
|
1220 |
+
cls = str(self.original_tokenizer.cls_token)
|
1221 |
+
sep = str(self.original_tokenizer.sep_token)
|
1222 |
+
cls_token_id = self.original_tokenizer.cls_token_id
|
1223 |
+
sep_token_id = self.original_tokenizer.sep_token_id
|
1224 |
+
|
1225 |
+
tokenizer.post_processor = processors.TemplateProcessing(
|
1226 |
+
single=f"{cls} $A {sep}",
|
1227 |
+
pair=f"{cls} $A {sep} $B {sep}",
|
1228 |
+
special_tokens=[
|
1229 |
+
(cls, cls_token_id),
|
1230 |
+
(sep, sep_token_id),
|
1231 |
+
],
|
1232 |
+
)
|
1233 |
+
|
1234 |
+
return tokenizer
|
1235 |
+
|
1236 |
+
|
1237 |
+
SLOW_TO_FAST_CONVERTERS = {
|
1238 |
+
"AlbertTokenizer": AlbertConverter,
|
1239 |
+
"BartTokenizer": RobertaConverter,
|
1240 |
+
"BarthezTokenizer": BarthezConverter,
|
1241 |
+
"BertTokenizer": BertConverter,
|
1242 |
+
"BigBirdTokenizer": BigBirdConverter,
|
1243 |
+
"BlenderbotTokenizer": BlenderbotConverter,
|
1244 |
+
"CamembertTokenizer": CamembertConverter,
|
1245 |
+
"CLIPTokenizer": CLIPConverter,
|
1246 |
+
"CodeGenTokenizer": GPT2Converter,
|
1247 |
+
"ConvBertTokenizer": BertConverter,
|
1248 |
+
"DebertaTokenizer": DebertaConverter,
|
1249 |
+
"DebertaV2Tokenizer": DebertaV2Converter,
|
1250 |
+
"DistilBertTokenizer": BertConverter,
|
1251 |
+
"DPRReaderTokenizer": BertConverter,
|
1252 |
+
"DPRQuestionEncoderTokenizer": BertConverter,
|
1253 |
+
"DPRContextEncoderTokenizer": BertConverter,
|
1254 |
+
"ElectraTokenizer": BertConverter,
|
1255 |
+
"FNetTokenizer": AlbertConverter,
|
1256 |
+
"FunnelTokenizer": FunnelConverter,
|
1257 |
+
"GPT2Tokenizer": GPT2Converter,
|
1258 |
+
"HerbertTokenizer": HerbertConverter,
|
1259 |
+
"LayoutLMTokenizer": BertConverter,
|
1260 |
+
"LayoutLMv2Tokenizer": BertConverter,
|
1261 |
+
"LayoutLMv3Tokenizer": RobertaConverter,
|
1262 |
+
"LayoutXLMTokenizer": XLMRobertaConverter,
|
1263 |
+
"LongformerTokenizer": RobertaConverter,
|
1264 |
+
"LEDTokenizer": RobertaConverter,
|
1265 |
+
"LxmertTokenizer": BertConverter,
|
1266 |
+
"MarkupLMTokenizer": MarkupLMConverter,
|
1267 |
+
"MBartTokenizer": MBartConverter,
|
1268 |
+
"MBart50Tokenizer": MBart50Converter,
|
1269 |
+
"MPNetTokenizer": MPNetConverter,
|
1270 |
+
"MobileBertTokenizer": BertConverter,
|
1271 |
+
"MvpTokenizer": RobertaConverter,
|
1272 |
+
"NllbTokenizer": NllbConverter,
|
1273 |
+
"OpenAIGPTTokenizer": OpenAIGPTConverter,
|
1274 |
+
"PegasusTokenizer": PegasusConverter,
|
1275 |
+
"RealmTokenizer": BertConverter,
|
1276 |
+
"ReformerTokenizer": ReformerConverter,
|
1277 |
+
"RemBertTokenizer": RemBertConverter,
|
1278 |
+
"RetriBertTokenizer": BertConverter,
|
1279 |
+
"RobertaTokenizer": RobertaConverter,
|
1280 |
+
"RoFormerTokenizer": RoFormerConverter,
|
1281 |
+
"SqueezeBertTokenizer": BertConverter,
|
1282 |
+
"T5Tokenizer": T5Converter,
|
1283 |
+
"WhisperTokenizer": WhisperConverter,
|
1284 |
+
"XLMRobertaTokenizer": XLMRobertaConverter,
|
1285 |
+
"XLNetTokenizer": XLNetConverter,
|
1286 |
+
"SplinterTokenizer": SplinterConverter,
|
1287 |
+
"XGLMTokenizer": XGLMConverter,
|
1288 |
+
"LlamaTokenizer": LlamaConverter,
|
1289 |
+
"CodeLlamaTokenizer": LlamaConverter,
|
1290 |
+
}
|
1291 |
+
|
1292 |
+
|
1293 |
+
def convert_slow_tokenizer(transformer_tokenizer) -> Tokenizer:
|
1294 |
+
"""
|
1295 |
+
Utilities to convert a slow tokenizer instance in a fast tokenizer instance.
|
1296 |
+
|
1297 |
+
Args:
|
1298 |
+
transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
|
1299 |
+
Instance of a slow tokenizer to convert in the backend tokenizer for
|
1300 |
+
[`~tokenization_utils_base.PreTrainedTokenizerFast`].
|
1301 |
+
|
1302 |
+
Return:
|
1303 |
+
A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
|
1304 |
+
[`~tokenization_utils_base.PreTrainedTokenizerFast`]
|
1305 |
+
"""
|
1306 |
+
|
1307 |
+
tokenizer_class_name = transformer_tokenizer.__class__.__name__
|
1308 |
+
|
1309 |
+
if tokenizer_class_name not in SLOW_TO_FAST_CONVERTERS:
|
1310 |
+
raise ValueError(
|
1311 |
+
f"An instance of tokenizer class {tokenizer_class_name} cannot be converted in a Fast tokenizer instance."
|
1312 |
+
" No converter was found. Currently available slow->fast convertors:"
|
1313 |
+
f" {list(SLOW_TO_FAST_CONVERTERS.keys())}"
|
1314 |
+
)
|
1315 |
+
|
1316 |
+
converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
|
1317 |
+
|
1318 |
+
return converter_class(transformer_tokenizer).converted()
|
transformers_4_35_0/convert_slow_tokenizers_checkpoints_to_fast.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Convert slow tokenizers checkpoints in fast (serialization format of the `tokenizers` library)"""
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import os
|
19 |
+
|
20 |
+
import transformers
|
21 |
+
|
22 |
+
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
|
23 |
+
from .utils import logging
|
24 |
+
|
25 |
+
|
26 |
+
logging.set_verbosity_info()
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__)
|
29 |
+
|
30 |
+
|
31 |
+
TOKENIZER_CLASSES = {name: getattr(transformers, name + "Fast") for name in SLOW_TO_FAST_CONVERTERS}
|
32 |
+
|
33 |
+
|
34 |
+
def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path, force_download):
|
35 |
+
if tokenizer_name is not None and tokenizer_name not in TOKENIZER_CLASSES:
|
36 |
+
raise ValueError(f"Unrecognized tokenizer name, should be one of {list(TOKENIZER_CLASSES.keys())}.")
|
37 |
+
|
38 |
+
if tokenizer_name is None:
|
39 |
+
tokenizer_names = TOKENIZER_CLASSES
|
40 |
+
else:
|
41 |
+
tokenizer_names = {tokenizer_name: getattr(transformers, tokenizer_name + "Fast")}
|
42 |
+
|
43 |
+
logger.info(f"Loading tokenizer classes: {tokenizer_names}")
|
44 |
+
|
45 |
+
for tokenizer_name in tokenizer_names:
|
46 |
+
tokenizer_class = TOKENIZER_CLASSES[tokenizer_name]
|
47 |
+
|
48 |
+
add_prefix = True
|
49 |
+
if checkpoint_name is None:
|
50 |
+
checkpoint_names = list(tokenizer_class.max_model_input_sizes.keys())
|
51 |
+
else:
|
52 |
+
checkpoint_names = [checkpoint_name]
|
53 |
+
|
54 |
+
logger.info(f"For tokenizer {tokenizer_class.__class__.__name__} loading checkpoints: {checkpoint_names}")
|
55 |
+
|
56 |
+
for checkpoint in checkpoint_names:
|
57 |
+
logger.info(f"Loading {tokenizer_class.__class__.__name__} {checkpoint}")
|
58 |
+
|
59 |
+
# Load tokenizer
|
60 |
+
tokenizer = tokenizer_class.from_pretrained(checkpoint, force_download=force_download)
|
61 |
+
|
62 |
+
# Save fast tokenizer
|
63 |
+
logger.info(f"Save fast tokenizer to {dump_path} with prefix {checkpoint} add_prefix {add_prefix}")
|
64 |
+
|
65 |
+
# For organization names we create sub-directories
|
66 |
+
if "/" in checkpoint:
|
67 |
+
checkpoint_directory, checkpoint_prefix_name = checkpoint.split("/")
|
68 |
+
dump_path_full = os.path.join(dump_path, checkpoint_directory)
|
69 |
+
elif add_prefix:
|
70 |
+
checkpoint_prefix_name = checkpoint
|
71 |
+
dump_path_full = dump_path
|
72 |
+
else:
|
73 |
+
checkpoint_prefix_name = None
|
74 |
+
dump_path_full = dump_path
|
75 |
+
|
76 |
+
logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")
|
77 |
+
|
78 |
+
if checkpoint in list(tokenizer.pretrained_vocab_files_map.values())[0]:
|
79 |
+
file_path = list(tokenizer.pretrained_vocab_files_map.values())[0][checkpoint]
|
80 |
+
next_char = file_path.split(checkpoint)[-1][0]
|
81 |
+
if next_char == "/":
|
82 |
+
dump_path_full = os.path.join(dump_path_full, checkpoint_prefix_name)
|
83 |
+
checkpoint_prefix_name = None
|
84 |
+
|
85 |
+
logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")
|
86 |
+
|
87 |
+
file_names = tokenizer.save_pretrained(
|
88 |
+
dump_path_full, legacy_format=False, filename_prefix=checkpoint_prefix_name
|
89 |
+
)
|
90 |
+
logger.info(f"=> File names {file_names}")
|
91 |
+
|
92 |
+
for file_name in file_names:
|
93 |
+
if not file_name.endswith("tokenizer.json"):
|
94 |
+
os.remove(file_name)
|
95 |
+
logger.info(f"=> removing {file_name}")
|
96 |
+
|
97 |
+
|
98 |
+
if __name__ == "__main__":
|
99 |
+
parser = argparse.ArgumentParser()
|
100 |
+
# Required parameters
|
101 |
+
parser.add_argument(
|
102 |
+
"--dump_path", default=None, type=str, required=True, help="Path to output generated fast tokenizer files."
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--tokenizer_name",
|
106 |
+
default=None,
|
107 |
+
type=str,
|
108 |
+
help=(
|
109 |
+
f"Optional tokenizer type selected in the list of {list(TOKENIZER_CLASSES.keys())}. If not given, will "
|
110 |
+
"download and convert all the checkpoints from AWS."
|
111 |
+
),
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"--checkpoint_name",
|
115 |
+
default=None,
|
116 |
+
type=str,
|
117 |
+
help="Optional checkpoint name. If not given, will download and convert the canonical checkpoints from AWS.",
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--force_download",
|
121 |
+
action="store_true",
|
122 |
+
help="Re-download checkpoints.",
|
123 |
+
)
|
124 |
+
args = parser.parse_args()
|
125 |
+
|
126 |
+
convert_slow_checkpoint_to_fast(args.tokenizer_name, args.checkpoint_name, args.dump_path, args.force_download)
|
transformers_4_35_0/convert_tf_hub_seq_to_seq_bert_to_pytorch.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Convert Seq2Seq TF Hub checkpoint."""
|
16 |
+
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
|
20 |
+
from . import (
|
21 |
+
BertConfig,
|
22 |
+
BertGenerationConfig,
|
23 |
+
BertGenerationDecoder,
|
24 |
+
BertGenerationEncoder,
|
25 |
+
load_tf_weights_in_bert_generation,
|
26 |
+
logging,
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
logging.set_verbosity_info()
|
31 |
+
|
32 |
+
|
33 |
+
def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder):
|
34 |
+
# Initialise PyTorch model
|
35 |
+
bert_config = BertConfig.from_pretrained(
|
36 |
+
"bert-large-cased",
|
37 |
+
vocab_size=vocab_size,
|
38 |
+
max_position_embeddings=512,
|
39 |
+
is_decoder=True,
|
40 |
+
add_cross_attention=True,
|
41 |
+
)
|
42 |
+
bert_config_dict = bert_config.to_dict()
|
43 |
+
del bert_config_dict["type_vocab_size"]
|
44 |
+
config = BertGenerationConfig(**bert_config_dict)
|
45 |
+
if is_encoder:
|
46 |
+
model = BertGenerationEncoder(config)
|
47 |
+
else:
|
48 |
+
model = BertGenerationDecoder(config)
|
49 |
+
print(f"Building PyTorch model from configuration: {config}")
|
50 |
+
|
51 |
+
# Load weights from tf checkpoint
|
52 |
+
load_tf_weights_in_bert_generation(
|
53 |
+
model,
|
54 |
+
tf_hub_path,
|
55 |
+
model_class="bert",
|
56 |
+
is_encoder_named_decoder=is_encoder_named_decoder,
|
57 |
+
is_encoder=is_encoder,
|
58 |
+
)
|
59 |
+
|
60 |
+
# Save pytorch-model
|
61 |
+
print(f"Save PyTorch model and config to {pytorch_dump_path}")
|
62 |
+
model.save_pretrained(pytorch_dump_path)
|
63 |
+
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
parser = argparse.ArgumentParser()
|
67 |
+
# Required parameters
|
68 |
+
parser.add_argument(
|
69 |
+
"--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
|
70 |
+
)
|
71 |
+
parser.add_argument(
|
72 |
+
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
|
73 |
+
)
|
74 |
+
parser.add_argument(
|
75 |
+
"--is_encoder_named_decoder",
|
76 |
+
action="store_true",
|
77 |
+
help="If decoder has to be renamed to encoder in PyTorch model.",
|
78 |
+
)
|
79 |
+
parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.")
|
80 |
+
parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model")
|
81 |
+
args = parser.parse_args()
|
82 |
+
convert_tf_checkpoint_to_pytorch(
|
83 |
+
args.tf_hub_path,
|
84 |
+
args.pytorch_dump_path,
|
85 |
+
args.is_encoder_named_decoder,
|
86 |
+
args.vocab_size,
|
87 |
+
is_encoder=args.is_encoder,
|
88 |
+
)
|
transformers_4_35_0/data/__init__.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .data_collator import (
|
16 |
+
DataCollatorForLanguageModeling,
|
17 |
+
DataCollatorForPermutationLanguageModeling,
|
18 |
+
DataCollatorForSeq2Seq,
|
19 |
+
DataCollatorForSOP,
|
20 |
+
DataCollatorForTokenClassification,
|
21 |
+
DataCollatorForWholeWordMask,
|
22 |
+
DataCollatorWithPadding,
|
23 |
+
DefaultDataCollator,
|
24 |
+
default_data_collator,
|
25 |
+
)
|
26 |
+
from .metrics import glue_compute_metrics, xnli_compute_metrics
|
27 |
+
from .processors import (
|
28 |
+
DataProcessor,
|
29 |
+
InputExample,
|
30 |
+
InputFeatures,
|
31 |
+
SingleSentenceClassificationProcessor,
|
32 |
+
SquadExample,
|
33 |
+
SquadFeatures,
|
34 |
+
SquadV1Processor,
|
35 |
+
SquadV2Processor,
|
36 |
+
glue_convert_examples_to_features,
|
37 |
+
glue_output_modes,
|
38 |
+
glue_processors,
|
39 |
+
glue_tasks_num_labels,
|
40 |
+
squad_convert_examples_to_features,
|
41 |
+
xnli_output_modes,
|
42 |
+
xnli_processors,
|
43 |
+
xnli_tasks_num_labels,
|
44 |
+
)
|
transformers_4_35_0/data/data_collator.py
ADDED
@@ -0,0 +1,1535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import random
|
16 |
+
import warnings
|
17 |
+
from collections.abc import Mapping
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from random import randint
|
20 |
+
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
from ..models.bert import BertTokenizer, BertTokenizerFast
|
25 |
+
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
26 |
+
from ..utils import PaddingStrategy
|
27 |
+
|
28 |
+
|
29 |
+
InputDataClass = NewType("InputDataClass", Any)
|
30 |
+
|
31 |
+
"""
|
32 |
+
A DataCollator is a function that takes a list of samples from a Dataset and collate them into a batch, as a dictionary
|
33 |
+
of PyTorch/TensorFlow tensors or NumPy arrays.
|
34 |
+
"""
|
35 |
+
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, Any]])
|
36 |
+
|
37 |
+
|
38 |
+
class DataCollatorMixin:
|
39 |
+
def __call__(self, features, return_tensors=None):
|
40 |
+
if return_tensors is None:
|
41 |
+
return_tensors = self.return_tensors
|
42 |
+
if return_tensors == "tf":
|
43 |
+
return self.tf_call(features)
|
44 |
+
elif return_tensors == "pt":
|
45 |
+
return self.torch_call(features)
|
46 |
+
elif return_tensors == "np":
|
47 |
+
return self.numpy_call(features)
|
48 |
+
else:
|
49 |
+
raise ValueError(f"Framework '{return_tensors}' not recognized!")
|
50 |
+
|
51 |
+
|
52 |
+
def default_data_collator(features: List[InputDataClass], return_tensors="pt") -> Dict[str, Any]:
|
53 |
+
"""
|
54 |
+
Very simple data collator that simply collates batches of dict-like objects and performs special handling for
|
55 |
+
potential keys named:
|
56 |
+
|
57 |
+
- `label`: handles a single value (int or float) per object
|
58 |
+
- `label_ids`: handles a list of values per object
|
59 |
+
|
60 |
+
Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
|
61 |
+
to the model. See glue and ner for example of how it's useful.
|
62 |
+
"""
|
63 |
+
|
64 |
+
# In this function we'll make the assumption that all `features` in the batch
|
65 |
+
# have the same attributes.
|
66 |
+
# So we will look at the first element as a proxy for what attributes exist
|
67 |
+
# on the whole batch.
|
68 |
+
|
69 |
+
if return_tensors == "pt":
|
70 |
+
return torch_default_data_collator(features)
|
71 |
+
elif return_tensors == "tf":
|
72 |
+
return tf_default_data_collator(features)
|
73 |
+
elif return_tensors == "np":
|
74 |
+
return numpy_default_data_collator(features)
|
75 |
+
|
76 |
+
|
77 |
+
@dataclass
|
78 |
+
class DefaultDataCollator(DataCollatorMixin):
|
79 |
+
"""
|
80 |
+
Very simple data collator that simply collates batches of dict-like objects and performs special handling for
|
81 |
+
potential keys named:
|
82 |
+
|
83 |
+
- `label`: handles a single value (int or float) per object
|
84 |
+
- `label_ids`: handles a list of values per object
|
85 |
+
|
86 |
+
Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
|
87 |
+
to the model. See glue and ner for example of how it's useful.
|
88 |
+
|
89 |
+
This is an object (like other data collators) rather than a pure function like default_data_collator. This can be
|
90 |
+
helpful if you need to set a return_tensors value at initialization.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
94 |
+
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
95 |
+
"""
|
96 |
+
|
97 |
+
return_tensors: str = "pt"
|
98 |
+
|
99 |
+
def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, Any]:
|
100 |
+
if return_tensors is None:
|
101 |
+
return_tensors = self.return_tensors
|
102 |
+
return default_data_collator(features, return_tensors)
|
103 |
+
|
104 |
+
|
105 |
+
def torch_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
106 |
+
import torch
|
107 |
+
|
108 |
+
if not isinstance(features[0], Mapping):
|
109 |
+
features = [vars(f) for f in features]
|
110 |
+
first = features[0]
|
111 |
+
batch = {}
|
112 |
+
|
113 |
+
# Special handling for labels.
|
114 |
+
# Ensure that tensor is created with the correct type
|
115 |
+
# (it should be automatically the case, but let's make sure of it.)
|
116 |
+
if "label" in first and first["label"] is not None:
|
117 |
+
label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
|
118 |
+
dtype = torch.long if isinstance(label, int) else torch.float
|
119 |
+
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
120 |
+
elif "label_ids" in first and first["label_ids"] is not None:
|
121 |
+
if isinstance(first["label_ids"], torch.Tensor):
|
122 |
+
batch["labels"] = torch.stack([f["label_ids"] for f in features])
|
123 |
+
else:
|
124 |
+
dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
|
125 |
+
batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
|
126 |
+
|
127 |
+
# Handling of all other possible keys.
|
128 |
+
# Again, we will use the first element to figure out which key/values are not None for this model.
|
129 |
+
for k, v in first.items():
|
130 |
+
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
131 |
+
if isinstance(v, torch.Tensor):
|
132 |
+
batch[k] = torch.stack([f[k] for f in features])
|
133 |
+
elif isinstance(v, np.ndarray):
|
134 |
+
batch[k] = torch.tensor(np.stack([f[k] for f in features]))
|
135 |
+
else:
|
136 |
+
batch[k] = torch.tensor([f[k] for f in features])
|
137 |
+
|
138 |
+
return batch
|
139 |
+
|
140 |
+
|
141 |
+
def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
142 |
+
import tensorflow as tf
|
143 |
+
|
144 |
+
if not isinstance(features[0], Mapping):
|
145 |
+
features = [vars(f) for f in features]
|
146 |
+
first = features[0]
|
147 |
+
batch = {}
|
148 |
+
|
149 |
+
# Special handling for labels.
|
150 |
+
# Ensure that tensor is created with the correct type
|
151 |
+
# (it should be automatically the case, but let's make sure of it.)
|
152 |
+
if "label" in first and first["label"] is not None:
|
153 |
+
label_col_name = "label"
|
154 |
+
elif "label_ids" in first and first["label_ids"] is not None:
|
155 |
+
label_col_name = "label_ids"
|
156 |
+
elif "labels" in first and first["labels"] is not None:
|
157 |
+
label_col_name = "labels"
|
158 |
+
else:
|
159 |
+
label_col_name = None
|
160 |
+
if label_col_name is not None:
|
161 |
+
if isinstance(first[label_col_name], tf.Tensor):
|
162 |
+
dtype = tf.int64 if first[label_col_name].dtype.is_integer else tf.float32
|
163 |
+
elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
|
164 |
+
dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
|
165 |
+
elif isinstance(first[label_col_name], (tuple, list)):
|
166 |
+
dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32
|
167 |
+
else:
|
168 |
+
dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32
|
169 |
+
batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype)
|
170 |
+
# Handling of all other possible keys.
|
171 |
+
# Again, we will use the first element to figure out which key/values are not None for this model.
|
172 |
+
for k, v in first.items():
|
173 |
+
if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str):
|
174 |
+
if isinstance(v, (tf.Tensor, np.ndarray)):
|
175 |
+
batch[k] = tf.stack([f[k] for f in features])
|
176 |
+
else:
|
177 |
+
batch[k] = tf.convert_to_tensor([f[k] for f in features])
|
178 |
+
|
179 |
+
return batch
|
180 |
+
|
181 |
+
|
182 |
+
def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
183 |
+
if not isinstance(features[0], Mapping):
|
184 |
+
features = [vars(f) for f in features]
|
185 |
+
first = features[0]
|
186 |
+
batch = {}
|
187 |
+
|
188 |
+
# Special handling for labels.
|
189 |
+
# Ensure that tensor is created with the correct type
|
190 |
+
# (it should be automatically the case, but let's make sure of it.)
|
191 |
+
if "label" in first and first["label"] is not None:
|
192 |
+
label = first["label"].item() if isinstance(first["label"], np.ndarray) else first["label"]
|
193 |
+
dtype = np.int64 if isinstance(label, int) else np.float32
|
194 |
+
batch["labels"] = np.array([f["label"] for f in features], dtype=dtype)
|
195 |
+
elif "label_ids" in first and first["label_ids"] is not None:
|
196 |
+
if isinstance(first["label_ids"], np.ndarray):
|
197 |
+
batch["labels"] = np.stack([f["label_ids"] for f in features])
|
198 |
+
else:
|
199 |
+
dtype = np.int64 if type(first["label_ids"][0]) is int else np.float32
|
200 |
+
batch["labels"] = np.array([f["label_ids"] for f in features], dtype=dtype)
|
201 |
+
|
202 |
+
# Handling of all other possible keys.
|
203 |
+
# Again, we will use the first element to figure out which key/values are not None for this model.
|
204 |
+
for k, v in first.items():
|
205 |
+
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
206 |
+
if isinstance(v, np.ndarray):
|
207 |
+
batch[k] = np.stack([f[k] for f in features])
|
208 |
+
else:
|
209 |
+
batch[k] = np.array([f[k] for f in features])
|
210 |
+
|
211 |
+
return batch
|
212 |
+
|
213 |
+
|
214 |
+
@dataclass
|
215 |
+
class DataCollatorWithPadding:
|
216 |
+
"""
|
217 |
+
Data collator that will dynamically pad the inputs received.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
221 |
+
The tokenizer used for encoding the data.
|
222 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
223 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
224 |
+
among:
|
225 |
+
|
226 |
+
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
|
227 |
+
sequence is provided).
|
228 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
229 |
+
acceptable input length for the model if that argument is not provided.
|
230 |
+
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
|
231 |
+
max_length (`int`, *optional*):
|
232 |
+
Maximum length of the returned list and optionally padding length (see above).
|
233 |
+
pad_to_multiple_of (`int`, *optional*):
|
234 |
+
If set will pad the sequence to a multiple of the provided value.
|
235 |
+
|
236 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
237 |
+
7.5 (Volta).
|
238 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
239 |
+
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
240 |
+
"""
|
241 |
+
|
242 |
+
tokenizer: PreTrainedTokenizerBase
|
243 |
+
padding: Union[bool, str, PaddingStrategy] = True
|
244 |
+
max_length: Optional[int] = None
|
245 |
+
pad_to_multiple_of: Optional[int] = None
|
246 |
+
return_tensors: str = "pt"
|
247 |
+
|
248 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
|
249 |
+
batch = self.tokenizer.pad(
|
250 |
+
features,
|
251 |
+
padding=self.padding,
|
252 |
+
max_length=self.max_length,
|
253 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
254 |
+
return_tensors=self.return_tensors,
|
255 |
+
)
|
256 |
+
if "label" in batch:
|
257 |
+
batch["labels"] = batch["label"]
|
258 |
+
del batch["label"]
|
259 |
+
if "label_ids" in batch:
|
260 |
+
batch["labels"] = batch["label_ids"]
|
261 |
+
del batch["label_ids"]
|
262 |
+
return batch
|
263 |
+
|
264 |
+
|
265 |
+
@dataclass
|
266 |
+
class DataCollatorForTokenClassification(DataCollatorMixin):
|
267 |
+
"""
|
268 |
+
Data collator that will dynamically pad the inputs received, as well as the labels.
|
269 |
+
|
270 |
+
Args:
|
271 |
+
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
272 |
+
The tokenizer used for encoding the data.
|
273 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
274 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
275 |
+
among:
|
276 |
+
|
277 |
+
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
|
278 |
+
sequence is provided).
|
279 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
280 |
+
acceptable input length for the model if that argument is not provided.
|
281 |
+
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
|
282 |
+
max_length (`int`, *optional*):
|
283 |
+
Maximum length of the returned list and optionally padding length (see above).
|
284 |
+
pad_to_multiple_of (`int`, *optional*):
|
285 |
+
If set will pad the sequence to a multiple of the provided value.
|
286 |
+
|
287 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
288 |
+
7.5 (Volta).
|
289 |
+
label_pad_token_id (`int`, *optional*, defaults to -100):
|
290 |
+
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
|
291 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
292 |
+
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
293 |
+
"""
|
294 |
+
|
295 |
+
tokenizer: PreTrainedTokenizerBase
|
296 |
+
padding: Union[bool, str, PaddingStrategy] = True
|
297 |
+
max_length: Optional[int] = None
|
298 |
+
pad_to_multiple_of: Optional[int] = None
|
299 |
+
label_pad_token_id: int = -100
|
300 |
+
return_tensors: str = "pt"
|
301 |
+
|
302 |
+
def torch_call(self, features):
|
303 |
+
import torch
|
304 |
+
|
305 |
+
label_name = "label" if "label" in features[0].keys() else "labels"
|
306 |
+
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
307 |
+
|
308 |
+
no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
|
309 |
+
|
310 |
+
batch = self.tokenizer.pad(
|
311 |
+
no_labels_features,
|
312 |
+
padding=self.padding,
|
313 |
+
max_length=self.max_length,
|
314 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
315 |
+
return_tensors="pt",
|
316 |
+
)
|
317 |
+
|
318 |
+
if labels is None:
|
319 |
+
return batch
|
320 |
+
|
321 |
+
sequence_length = batch["input_ids"].shape[1]
|
322 |
+
padding_side = self.tokenizer.padding_side
|
323 |
+
|
324 |
+
def to_list(tensor_or_iterable):
|
325 |
+
if isinstance(tensor_or_iterable, torch.Tensor):
|
326 |
+
return tensor_or_iterable.tolist()
|
327 |
+
return list(tensor_or_iterable)
|
328 |
+
|
329 |
+
if padding_side == "right":
|
330 |
+
batch[label_name] = [
|
331 |
+
to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
|
332 |
+
]
|
333 |
+
else:
|
334 |
+
batch[label_name] = [
|
335 |
+
[self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
|
336 |
+
]
|
337 |
+
|
338 |
+
batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
|
339 |
+
return batch
|
340 |
+
|
341 |
+
def tf_call(self, features):
|
342 |
+
import tensorflow as tf
|
343 |
+
|
344 |
+
label_name = "label" if "label" in features[0].keys() else "labels"
|
345 |
+
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
346 |
+
batch = self.tokenizer.pad(
|
347 |
+
features,
|
348 |
+
padding=self.padding,
|
349 |
+
max_length=self.max_length,
|
350 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
351 |
+
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
|
352 |
+
return_tensors="tf" if labels is None else None,
|
353 |
+
)
|
354 |
+
|
355 |
+
if labels is None:
|
356 |
+
return batch
|
357 |
+
|
358 |
+
sequence_length = tf.convert_to_tensor(batch["input_ids"]).shape[1]
|
359 |
+
padding_side = self.tokenizer.padding_side
|
360 |
+
if padding_side == "right":
|
361 |
+
batch["labels"] = [
|
362 |
+
list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
|
363 |
+
]
|
364 |
+
else:
|
365 |
+
batch["labels"] = [
|
366 |
+
[self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
|
367 |
+
]
|
368 |
+
|
369 |
+
batch = {k: tf.convert_to_tensor(v, dtype=tf.int64) for k, v in batch.items()}
|
370 |
+
return batch
|
371 |
+
|
372 |
+
def numpy_call(self, features):
|
373 |
+
label_name = "label" if "label" in features[0].keys() else "labels"
|
374 |
+
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
375 |
+
batch = self.tokenizer.pad(
|
376 |
+
features,
|
377 |
+
padding=self.padding,
|
378 |
+
max_length=self.max_length,
|
379 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
380 |
+
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
|
381 |
+
return_tensors="np" if labels is None else None,
|
382 |
+
)
|
383 |
+
|
384 |
+
if labels is None:
|
385 |
+
return batch
|
386 |
+
|
387 |
+
sequence_length = np.array(batch["input_ids"]).shape[1]
|
388 |
+
padding_side = self.tokenizer.padding_side
|
389 |
+
if padding_side == "right":
|
390 |
+
batch["labels"] = [
|
391 |
+
list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
|
392 |
+
]
|
393 |
+
else:
|
394 |
+
batch["labels"] = [
|
395 |
+
[self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
|
396 |
+
]
|
397 |
+
|
398 |
+
batch = {k: np.array(v, dtype=np.int64) for k, v in batch.items()}
|
399 |
+
return batch
|
400 |
+
|
401 |
+
|
402 |
+
def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
|
403 |
+
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
404 |
+
import torch
|
405 |
+
|
406 |
+
# Tensorize if necessary.
|
407 |
+
if isinstance(examples[0], (list, tuple, np.ndarray)):
|
408 |
+
examples = [torch.tensor(e, dtype=torch.long) for e in examples]
|
409 |
+
|
410 |
+
length_of_first = examples[0].size(0)
|
411 |
+
|
412 |
+
# Check if padding is necessary.
|
413 |
+
|
414 |
+
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
|
415 |
+
if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
|
416 |
+
return torch.stack(examples, dim=0)
|
417 |
+
|
418 |
+
# If yes, check if we have a `pad_token`.
|
419 |
+
if tokenizer._pad_token is None:
|
420 |
+
raise ValueError(
|
421 |
+
"You are attempting to pad samples but the tokenizer you are using"
|
422 |
+
f" ({tokenizer.__class__.__name__}) does not have a pad token."
|
423 |
+
)
|
424 |
+
|
425 |
+
# Creating the full tensor and filling it with our data.
|
426 |
+
max_length = max(x.size(0) for x in examples)
|
427 |
+
if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
428 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
429 |
+
result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
|
430 |
+
for i, example in enumerate(examples):
|
431 |
+
if tokenizer.padding_side == "right":
|
432 |
+
result[i, : example.shape[0]] = example
|
433 |
+
else:
|
434 |
+
result[i, -example.shape[0] :] = example
|
435 |
+
return result
|
436 |
+
|
437 |
+
|
438 |
+
def _tf_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
|
439 |
+
import tensorflow as tf
|
440 |
+
|
441 |
+
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
442 |
+
# Tensorize if necessary.
|
443 |
+
if isinstance(examples[0], (list, tuple)):
|
444 |
+
examples = [tf.convert_to_tensor(e, dtype=tf.int64) for e in examples]
|
445 |
+
|
446 |
+
# Check if padding is necessary.
|
447 |
+
length_of_first = len(examples[0])
|
448 |
+
are_tensors_same_length = all(len(x) == length_of_first for x in examples)
|
449 |
+
if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
|
450 |
+
return tf.stack(examples, axis=0)
|
451 |
+
|
452 |
+
# If yes, check if we have a `pad_token`.
|
453 |
+
if tokenizer._pad_token is None:
|
454 |
+
raise ValueError(
|
455 |
+
"You are attempting to pad samples but the tokenizer you are using"
|
456 |
+
f" ({tokenizer.__class__.__name__}) does not have a pad token."
|
457 |
+
)
|
458 |
+
|
459 |
+
# Creating the full tensor and filling it with our data.
|
460 |
+
max_length = max(len(x) for x in examples)
|
461 |
+
if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
462 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
463 |
+
# result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
|
464 |
+
result = []
|
465 |
+
rank = tf.rank(examples[0])
|
466 |
+
paddings = np.zeros((rank, 2), dtype=np.int32)
|
467 |
+
for example in examples:
|
468 |
+
if tokenizer.padding_side == "right":
|
469 |
+
paddings[0, 1] = max_length - len(example)
|
470 |
+
else:
|
471 |
+
paddings[0, 0] = max_length - len(example)
|
472 |
+
result.append(tf.pad(example, paddings, constant_values=tokenizer.pad_token_id))
|
473 |
+
return tf.stack(result, axis=0)
|
474 |
+
|
475 |
+
|
476 |
+
def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
|
477 |
+
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
|
478 |
+
# Tensorize if necessary.
|
479 |
+
if isinstance(examples[0], (list, tuple)):
|
480 |
+
examples = [np.array(e, dtype=np.int64) for e in examples]
|
481 |
+
|
482 |
+
# Check if padding is necessary.
|
483 |
+
length_of_first = len(examples[0])
|
484 |
+
are_tensors_same_length = all(len(x) == length_of_first for x in examples)
|
485 |
+
if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
|
486 |
+
return np.stack(examples, axis=0)
|
487 |
+
|
488 |
+
# If yes, check if we have a `pad_token`.
|
489 |
+
if tokenizer._pad_token is None:
|
490 |
+
raise ValueError(
|
491 |
+
"You are attempting to pad samples but the tokenizer you are using"
|
492 |
+
f" ({tokenizer.__class__.__name__}) does not have a pad token."
|
493 |
+
)
|
494 |
+
|
495 |
+
# Creating the full tensor and filling it with our data.
|
496 |
+
max_length = max(len(x) for x in examples)
|
497 |
+
if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
498 |
+
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
499 |
+
result = np.full(shape=(len(examples), max_length), fill_value=tokenizer.pad_token_id, dtype=examples[0].dtype)
|
500 |
+
for i, example in enumerate(examples):
|
501 |
+
if tokenizer.padding_side == "right":
|
502 |
+
result[i, : example.shape[0]] = example
|
503 |
+
else:
|
504 |
+
result[i, -example.shape[0] :] = example
|
505 |
+
return result
|
506 |
+
|
507 |
+
|
508 |
+
def tolist(x):
|
509 |
+
if isinstance(x, list):
|
510 |
+
return x
|
511 |
+
elif hasattr(x, "numpy"): # Checks for TF tensors without needing the import
|
512 |
+
x = x.numpy()
|
513 |
+
return x.tolist()
|
514 |
+
|
515 |
+
|
516 |
+
@dataclass
|
517 |
+
class DataCollatorForSeq2Seq:
|
518 |
+
"""
|
519 |
+
Data collator that will dynamically pad the inputs received, as well as the labels.
|
520 |
+
|
521 |
+
Args:
|
522 |
+
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
523 |
+
The tokenizer used for encoding the data.
|
524 |
+
model ([`PreTrainedModel`], *optional*):
|
525 |
+
The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
|
526 |
+
prepare the *decoder_input_ids*
|
527 |
+
|
528 |
+
This is useful when using *label_smoothing* to avoid calculating loss twice.
|
529 |
+
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
530 |
+
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
|
531 |
+
among:
|
532 |
+
|
533 |
+
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
|
534 |
+
sequence is provided).
|
535 |
+
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
|
536 |
+
acceptable input length for the model if that argument is not provided.
|
537 |
+
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
|
538 |
+
max_length (`int`, *optional*):
|
539 |
+
Maximum length of the returned list and optionally padding length (see above).
|
540 |
+
pad_to_multiple_of (`int`, *optional*):
|
541 |
+
If set will pad the sequence to a multiple of the provided value.
|
542 |
+
|
543 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
|
544 |
+
7.5 (Volta).
|
545 |
+
label_pad_token_id (`int`, *optional*, defaults to -100):
|
546 |
+
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
|
547 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
548 |
+
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
549 |
+
"""
|
550 |
+
|
551 |
+
tokenizer: PreTrainedTokenizerBase
|
552 |
+
model: Optional[Any] = None
|
553 |
+
padding: Union[bool, str, PaddingStrategy] = True
|
554 |
+
max_length: Optional[int] = None
|
555 |
+
pad_to_multiple_of: Optional[int] = None
|
556 |
+
label_pad_token_id: int = -100
|
557 |
+
return_tensors: str = "pt"
|
558 |
+
|
559 |
+
def __call__(self, features, return_tensors=None):
|
560 |
+
if return_tensors is None:
|
561 |
+
return_tensors = self.return_tensors
|
562 |
+
labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
|
563 |
+
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
|
564 |
+
# same length to return tensors.
|
565 |
+
if labels is not None:
|
566 |
+
max_label_length = max(len(l) for l in labels)
|
567 |
+
if self.pad_to_multiple_of is not None:
|
568 |
+
max_label_length = (
|
569 |
+
(max_label_length + self.pad_to_multiple_of - 1)
|
570 |
+
// self.pad_to_multiple_of
|
571 |
+
* self.pad_to_multiple_of
|
572 |
+
)
|
573 |
+
|
574 |
+
padding_side = self.tokenizer.padding_side
|
575 |
+
for feature in features:
|
576 |
+
remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
|
577 |
+
if isinstance(feature["labels"], list):
|
578 |
+
feature["labels"] = (
|
579 |
+
feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
|
580 |
+
)
|
581 |
+
elif padding_side == "right":
|
582 |
+
feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
|
583 |
+
else:
|
584 |
+
feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
|
585 |
+
|
586 |
+
features = self.tokenizer.pad(
|
587 |
+
features,
|
588 |
+
padding=self.padding,
|
589 |
+
max_length=self.max_length,
|
590 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
591 |
+
return_tensors=return_tensors,
|
592 |
+
)
|
593 |
+
|
594 |
+
# prepare decoder_input_ids
|
595 |
+
if (
|
596 |
+
labels is not None
|
597 |
+
and self.model is not None
|
598 |
+
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
|
599 |
+
):
|
600 |
+
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])
|
601 |
+
features["decoder_input_ids"] = decoder_input_ids
|
602 |
+
|
603 |
+
return features
|
604 |
+
|
605 |
+
|
606 |
+
@dataclass
|
607 |
+
class DataCollatorForLanguageModeling(DataCollatorMixin):
|
608 |
+
"""
|
609 |
+
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
|
610 |
+
are not all of the same length.
|
611 |
+
|
612 |
+
Args:
|
613 |
+
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
|
614 |
+
The tokenizer used for encoding the data.
|
615 |
+
mlm (`bool`, *optional*, defaults to `True`):
|
616 |
+
Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
|
617 |
+
with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
|
618 |
+
tokens and the value to predict for the masked token.
|
619 |
+
mlm_probability (`float`, *optional*, defaults to 0.15):
|
620 |
+
The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
|
621 |
+
pad_to_multiple_of (`int`, *optional*):
|
622 |
+
If set will pad the sequence to a multiple of the provided value.
|
623 |
+
return_tensors (`str`):
|
624 |
+
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
|
625 |
+
|
626 |
+
<Tip>
|
627 |
+
|
628 |
+
For best performance, this data collator should be used with a dataset having items that are dictionaries or
|
629 |
+
BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
|
630 |
+
[`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
|
631 |
+
|
632 |
+
</Tip>"""
|
633 |
+
|
634 |
+
tokenizer: PreTrainedTokenizerBase
|
635 |
+
mlm: bool = True
|
636 |
+
mlm_probability: float = 0.15
|
637 |
+
pad_to_multiple_of: Optional[int] = None
|
638 |
+
tf_experimental_compile: bool = False
|
639 |
+
return_tensors: str = "pt"
|
640 |
+
|
641 |
+
def __post_init__(self):
|
642 |
+
if self.mlm and self.tokenizer.mask_token is None:
|
643 |
+
raise ValueError(
|
644 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
645 |
+
"You should pass `mlm=False` to train on causal language modeling instead."
|
646 |
+
)
|
647 |
+
if self.tf_experimental_compile:
|
648 |
+
import tensorflow as tf
|
649 |
+
|
650 |
+
self.tf_mask_tokens = tf.function(self.tf_mask_tokens, jit_compile=True)
|
651 |
+
|
652 |
+
@staticmethod
|
653 |
+
def tf_bernoulli(shape, probability):
|
654 |
+
import tensorflow as tf
|
655 |
+
|
656 |
+
prob_matrix = tf.fill(shape, probability)
|
657 |
+
return tf.cast(prob_matrix - tf.random.uniform(shape, 0, 1) >= 0, tf.bool)
|
658 |
+
|
659 |
+
def tf_mask_tokens(
|
660 |
+
self, inputs: Any, vocab_size, mask_token_id, special_tokens_mask: Optional[Any] = None
|
661 |
+
) -> Tuple[Any, Any]:
|
662 |
+
"""
|
663 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
664 |
+
"""
|
665 |
+
import tensorflow as tf
|
666 |
+
|
667 |
+
mask_token_id = tf.cast(mask_token_id, inputs.dtype)
|
668 |
+
|
669 |
+
input_shape = tf.shape(inputs)
|
670 |
+
# 1 for a special token, 0 for a normal token in the special tokens mask
|
671 |
+
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
672 |
+
masked_indices = self.tf_bernoulli(input_shape, self.mlm_probability) & ~special_tokens_mask
|
673 |
+
# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
|
674 |
+
labels = tf.where(masked_indices, inputs, -100)
|
675 |
+
|
676 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
677 |
+
indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
|
678 |
+
|
679 |
+
inputs = tf.where(indices_replaced, mask_token_id, inputs)
|
680 |
+
|
681 |
+
# 10% of the time, we replace masked input tokens with random word
|
682 |
+
indices_random = self.tf_bernoulli(input_shape, 0.1) & masked_indices & ~indices_replaced
|
683 |
+
random_words = tf.random.uniform(input_shape, maxval=vocab_size, dtype=inputs.dtype)
|
684 |
+
|
685 |
+
inputs = tf.where(indices_random, random_words, inputs)
|
686 |
+
|
687 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
688 |
+
return inputs, labels
|
689 |
+
|
690 |
+
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
691 |
+
import tensorflow as tf
|
692 |
+
|
693 |
+
# Handle dict or lists with proper padding and conversion to tensor.
|
694 |
+
if isinstance(examples[0], Mapping):
|
695 |
+
batch = self.tokenizer.pad(examples, return_tensors="tf", pad_to_multiple_of=self.pad_to_multiple_of)
|
696 |
+
else:
|
697 |
+
batch = {
|
698 |
+
"input_ids": _tf_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
699 |
+
}
|
700 |
+
|
701 |
+
# If special token mask has been preprocessed, pop it from the dict.
|
702 |
+
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
703 |
+
if self.mlm:
|
704 |
+
if special_tokens_mask is None:
|
705 |
+
special_tokens_mask = [
|
706 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
|
707 |
+
for val in batch["input_ids"].numpy().tolist()
|
708 |
+
]
|
709 |
+
# Cannot directly create as bool
|
710 |
+
special_tokens_mask = tf.cast(tf.convert_to_tensor(special_tokens_mask, dtype=tf.int64), tf.bool)
|
711 |
+
else:
|
712 |
+
special_tokens_mask = tf.cast(special_tokens_mask, tf.bool)
|
713 |
+
batch["input_ids"], batch["labels"] = self.tf_mask_tokens(
|
714 |
+
tf.cast(batch["input_ids"], tf.int64),
|
715 |
+
special_tokens_mask=special_tokens_mask,
|
716 |
+
mask_token_id=self.tokenizer.mask_token_id,
|
717 |
+
vocab_size=len(self.tokenizer),
|
718 |
+
)
|
719 |
+
else:
|
720 |
+
labels = batch["input_ids"]
|
721 |
+
if self.tokenizer.pad_token_id is not None:
|
722 |
+
# Replace self.tokenizer.pad_token_id with -100
|
723 |
+
labels = tf.where(labels == self.tokenizer.pad_token_id, -100, labels)
|
724 |
+
else:
|
725 |
+
labels = tf.identity(labels) # Makes a copy, just in case
|
726 |
+
batch["labels"] = labels
|
727 |
+
return batch
|
728 |
+
|
729 |
+
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
730 |
+
# Handle dict or lists with proper padding and conversion to tensor.
|
731 |
+
if isinstance(examples[0], Mapping):
|
732 |
+
batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
|
733 |
+
else:
|
734 |
+
batch = {
|
735 |
+
"input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
736 |
+
}
|
737 |
+
|
738 |
+
# If special token mask has been preprocessed, pop it from the dict.
|
739 |
+
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
740 |
+
if self.mlm:
|
741 |
+
batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
|
742 |
+
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
743 |
+
)
|
744 |
+
else:
|
745 |
+
labels = batch["input_ids"].clone()
|
746 |
+
if self.tokenizer.pad_token_id is not None:
|
747 |
+
labels[labels == self.tokenizer.pad_token_id] = -100
|
748 |
+
batch["labels"] = labels
|
749 |
+
return batch
|
750 |
+
|
751 |
+
def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
|
752 |
+
"""
|
753 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
754 |
+
"""
|
755 |
+
import torch
|
756 |
+
|
757 |
+
labels = inputs.clone()
|
758 |
+
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
759 |
+
probability_matrix = torch.full(labels.shape, self.mlm_probability)
|
760 |
+
if special_tokens_mask is None:
|
761 |
+
special_tokens_mask = [
|
762 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
763 |
+
]
|
764 |
+
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
|
765 |
+
else:
|
766 |
+
special_tokens_mask = special_tokens_mask.bool()
|
767 |
+
|
768 |
+
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
|
769 |
+
masked_indices = torch.bernoulli(probability_matrix).bool()
|
770 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
771 |
+
|
772 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
773 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
774 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
775 |
+
|
776 |
+
# 10% of the time, we replace masked input tokens with random word
|
777 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
778 |
+
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
779 |
+
inputs[indices_random] = random_words[indices_random]
|
780 |
+
|
781 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
782 |
+
return inputs, labels
|
783 |
+
|
784 |
+
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
785 |
+
# Handle dict or lists with proper padding and conversion to tensor.
|
786 |
+
if isinstance(examples[0], Mapping):
|
787 |
+
batch = self.tokenizer.pad(examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of)
|
788 |
+
else:
|
789 |
+
batch = {
|
790 |
+
"input_ids": _numpy_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
791 |
+
}
|
792 |
+
|
793 |
+
# If special token mask has been preprocessed, pop it from the dict.
|
794 |
+
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
795 |
+
if self.mlm:
|
796 |
+
batch["input_ids"], batch["labels"] = self.numpy_mask_tokens(
|
797 |
+
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
798 |
+
)
|
799 |
+
else:
|
800 |
+
labels = np.copy(batch["input_ids"])
|
801 |
+
if self.tokenizer.pad_token_id is not None:
|
802 |
+
labels[labels == self.tokenizer.pad_token_id] = -100
|
803 |
+
batch["labels"] = labels
|
804 |
+
return batch
|
805 |
+
|
806 |
+
def numpy_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
|
807 |
+
"""
|
808 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
809 |
+
"""
|
810 |
+
labels = np.copy(inputs)
|
811 |
+
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
812 |
+
probability_matrix = np.full(labels.shape, self.mlm_probability)
|
813 |
+
if special_tokens_mask is None:
|
814 |
+
special_tokens_mask = [
|
815 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
816 |
+
]
|
817 |
+
special_tokens_mask = np.array(special_tokens_mask, dtype=bool)
|
818 |
+
else:
|
819 |
+
special_tokens_mask = special_tokens_mask.astype(bool)
|
820 |
+
|
821 |
+
probability_matrix[special_tokens_mask] = 0
|
822 |
+
# Numpy doesn't have bernoulli, so we use a binomial with 1 trial
|
823 |
+
masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
|
824 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
825 |
+
|
826 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
827 |
+
indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
|
828 |
+
inputs[indices_replaced] = self.tokenizer.mask_token_id
|
829 |
+
|
830 |
+
# 10% of the time, we replace masked input tokens with random word
|
831 |
+
# indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
832 |
+
indices_random = (
|
833 |
+
np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
|
834 |
+
)
|
835 |
+
random_words = np.random.randint(
|
836 |
+
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
|
837 |
+
)
|
838 |
+
inputs[indices_random] = random_words
|
839 |
+
|
840 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
841 |
+
return inputs, labels
|
842 |
+
|
843 |
+
|
844 |
+
@dataclass
|
845 |
+
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
|
846 |
+
"""
|
847 |
+
Data collator used for language modeling that masks entire words.
|
848 |
+
|
849 |
+
- collates batches of tensors, honoring their tokenizer's pad_token
|
850 |
+
- preprocesses batches for masked language modeling
|
851 |
+
|
852 |
+
<Tip>
|
853 |
+
|
854 |
+
This collator relies on details of the implementation of subword tokenization by [`BertTokenizer`], specifically
|
855 |
+
that subword tokens are prefixed with *##*. For tokenizers that do not adhere to this scheme, this collator will
|
856 |
+
produce an output that is roughly equivalent to [`.DataCollatorForLanguageModeling`].
|
857 |
+
|
858 |
+
</Tip>"""
|
859 |
+
|
860 |
+
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
861 |
+
if isinstance(examples[0], Mapping):
|
862 |
+
input_ids = [e["input_ids"] for e in examples]
|
863 |
+
else:
|
864 |
+
input_ids = examples
|
865 |
+
examples = [{"input_ids": e} for e in examples]
|
866 |
+
|
867 |
+
batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
868 |
+
|
869 |
+
mask_labels = []
|
870 |
+
for e in examples:
|
871 |
+
ref_tokens = []
|
872 |
+
for id in tolist(e["input_ids"]):
|
873 |
+
token = self.tokenizer._convert_id_to_token(id)
|
874 |
+
ref_tokens.append(token)
|
875 |
+
|
876 |
+
# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
|
877 |
+
if "chinese_ref" in e:
|
878 |
+
ref_pos = tolist(e["chinese_ref"])
|
879 |
+
len_seq = len(e["input_ids"])
|
880 |
+
for i in range(len_seq):
|
881 |
+
if i in ref_pos:
|
882 |
+
ref_tokens[i] = "##" + ref_tokens[i]
|
883 |
+
mask_labels.append(self._whole_word_mask(ref_tokens))
|
884 |
+
batch_mask = _torch_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
885 |
+
inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)
|
886 |
+
return {"input_ids": inputs, "labels": labels}
|
887 |
+
|
888 |
+
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
889 |
+
import tensorflow as tf
|
890 |
+
|
891 |
+
if isinstance(examples[0], Mapping):
|
892 |
+
input_ids = [e["input_ids"] for e in examples]
|
893 |
+
else:
|
894 |
+
input_ids = examples
|
895 |
+
examples = [{"input_ids": e} for e in examples]
|
896 |
+
|
897 |
+
batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
898 |
+
|
899 |
+
mask_labels = []
|
900 |
+
for e in examples:
|
901 |
+
ref_tokens = []
|
902 |
+
for id in tolist(e["input_ids"]):
|
903 |
+
token = self.tokenizer._convert_id_to_token(id)
|
904 |
+
ref_tokens.append(token)
|
905 |
+
|
906 |
+
# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
|
907 |
+
if "chinese_ref" in e:
|
908 |
+
ref_pos = tolist(e["chinese_ref"])
|
909 |
+
len_seq = len(e["input_ids"])
|
910 |
+
for i in range(len_seq):
|
911 |
+
if i in ref_pos:
|
912 |
+
ref_tokens[i] = "##" + ref_tokens[i]
|
913 |
+
mask_labels.append(self._whole_word_mask(ref_tokens))
|
914 |
+
batch_mask = _tf_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
915 |
+
inputs, labels = self.tf_mask_tokens(tf.cast(batch_input, tf.int64), batch_mask)
|
916 |
+
return {"input_ids": inputs, "labels": labels}
|
917 |
+
|
918 |
+
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
919 |
+
if isinstance(examples[0], Mapping):
|
920 |
+
input_ids = [e["input_ids"] for e in examples]
|
921 |
+
else:
|
922 |
+
input_ids = examples
|
923 |
+
examples = [{"input_ids": e} for e in examples]
|
924 |
+
|
925 |
+
batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
926 |
+
|
927 |
+
mask_labels = []
|
928 |
+
for e in examples:
|
929 |
+
ref_tokens = []
|
930 |
+
for id in tolist(e["input_ids"]):
|
931 |
+
token = self.tokenizer._convert_id_to_token(id)
|
932 |
+
ref_tokens.append(token)
|
933 |
+
|
934 |
+
# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
|
935 |
+
if "chinese_ref" in e:
|
936 |
+
ref_pos = tolist(e["chinese_ref"])
|
937 |
+
len_seq = len(e["input_ids"])
|
938 |
+
for i in range(len_seq):
|
939 |
+
if i in ref_pos:
|
940 |
+
ref_tokens[i] = "##" + ref_tokens[i]
|
941 |
+
mask_labels.append(self._whole_word_mask(ref_tokens))
|
942 |
+
batch_mask = _numpy_collate_batch(mask_labels, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
|
943 |
+
inputs, labels = self.numpy_mask_tokens(batch_input, batch_mask)
|
944 |
+
return {"input_ids": inputs, "labels": labels}
|
945 |
+
|
946 |
+
def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
|
947 |
+
"""
|
948 |
+
Get 0/1 labels for masked tokens with whole word mask proxy
|
949 |
+
"""
|
950 |
+
if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
|
951 |
+
warnings.warn(
|
952 |
+
"DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
|
953 |
+
"Please refer to the documentation for more information."
|
954 |
+
)
|
955 |
+
|
956 |
+
cand_indexes = []
|
957 |
+
for i, token in enumerate(input_tokens):
|
958 |
+
if token == "[CLS]" or token == "[SEP]":
|
959 |
+
continue
|
960 |
+
|
961 |
+
if len(cand_indexes) >= 1 and token.startswith("##"):
|
962 |
+
cand_indexes[-1].append(i)
|
963 |
+
else:
|
964 |
+
cand_indexes.append([i])
|
965 |
+
|
966 |
+
random.shuffle(cand_indexes)
|
967 |
+
num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
|
968 |
+
masked_lms = []
|
969 |
+
covered_indexes = set()
|
970 |
+
for index_set in cand_indexes:
|
971 |
+
if len(masked_lms) >= num_to_predict:
|
972 |
+
break
|
973 |
+
# If adding a whole-word mask would exceed the maximum number of
|
974 |
+
# predictions, then just skip this candidate.
|
975 |
+
if len(masked_lms) + len(index_set) > num_to_predict:
|
976 |
+
continue
|
977 |
+
is_any_index_covered = False
|
978 |
+
for index in index_set:
|
979 |
+
if index in covered_indexes:
|
980 |
+
is_any_index_covered = True
|
981 |
+
break
|
982 |
+
if is_any_index_covered:
|
983 |
+
continue
|
984 |
+
for index in index_set:
|
985 |
+
covered_indexes.add(index)
|
986 |
+
masked_lms.append(index)
|
987 |
+
|
988 |
+
if len(covered_indexes) != len(masked_lms):
|
989 |
+
raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
|
990 |
+
mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
|
991 |
+
return mask_labels
|
992 |
+
|
993 |
+
def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
|
994 |
+
"""
|
995 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
|
996 |
+
'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
|
997 |
+
"""
|
998 |
+
import torch
|
999 |
+
|
1000 |
+
if self.tokenizer.mask_token is None:
|
1001 |
+
raise ValueError(
|
1002 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
|
1003 |
+
" --mlm flag if you want to use this tokenizer."
|
1004 |
+
)
|
1005 |
+
labels = inputs.clone()
|
1006 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
1007 |
+
|
1008 |
+
probability_matrix = mask_labels
|
1009 |
+
|
1010 |
+
special_tokens_mask = [
|
1011 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
1012 |
+
]
|
1013 |
+
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
|
1014 |
+
if self.tokenizer._pad_token is not None:
|
1015 |
+
padding_mask = labels.eq(self.tokenizer.pad_token_id)
|
1016 |
+
probability_matrix.masked_fill_(padding_mask, value=0.0)
|
1017 |
+
|
1018 |
+
masked_indices = probability_matrix.bool()
|
1019 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
1020 |
+
|
1021 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
1022 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
1023 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
1024 |
+
|
1025 |
+
# 10% of the time, we replace masked input tokens with random word
|
1026 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
1027 |
+
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
1028 |
+
inputs[indices_random] = random_words[indices_random]
|
1029 |
+
|
1030 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
1031 |
+
return inputs, labels
|
1032 |
+
|
1033 |
+
def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
|
1034 |
+
"""
|
1035 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
|
1036 |
+
'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
|
1037 |
+
"""
|
1038 |
+
import tensorflow as tf
|
1039 |
+
|
1040 |
+
input_shape = tf.shape(inputs)
|
1041 |
+
if self.tokenizer.mask_token is None:
|
1042 |
+
raise ValueError(
|
1043 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
|
1044 |
+
" --mlm flag if you want to use this tokenizer."
|
1045 |
+
)
|
1046 |
+
labels = tf.identity(inputs)
|
1047 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
1048 |
+
|
1049 |
+
masked_indices = tf.cast(mask_labels, tf.bool)
|
1050 |
+
|
1051 |
+
special_tokens_mask = [
|
1052 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
|
1053 |
+
]
|
1054 |
+
masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
|
1055 |
+
if self.tokenizer._pad_token is not None:
|
1056 |
+
padding_mask = inputs == self.tokenizer.pad_token_id
|
1057 |
+
masked_indices = masked_indices & ~padding_mask
|
1058 |
+
|
1059 |
+
# Replace unmasked indices with -100 in the labels since we only compute loss on masked tokens
|
1060 |
+
labels = tf.where(masked_indices, inputs, -100)
|
1061 |
+
|
1062 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
1063 |
+
indices_replaced = self.tf_bernoulli(input_shape, 0.8) & masked_indices
|
1064 |
+
|
1065 |
+
inputs = tf.where(indices_replaced, self.tokenizer.mask_token_id, inputs)
|
1066 |
+
|
1067 |
+
# 10% of the time, we replace masked input tokens with random word
|
1068 |
+
indices_random = self.tf_bernoulli(input_shape, 0.5) & masked_indices & ~indices_replaced
|
1069 |
+
random_words = tf.random.uniform(input_shape, maxval=len(self.tokenizer), dtype=tf.int64)
|
1070 |
+
inputs = tf.where(indices_random, random_words, inputs)
|
1071 |
+
|
1072 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
1073 |
+
return inputs, labels
|
1074 |
+
|
1075 |
+
def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
|
1076 |
+
"""
|
1077 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. Set
|
1078 |
+
'mask_labels' means we use whole word mask (wwm), we directly mask idxs according to it's ref.
|
1079 |
+
"""
|
1080 |
+
if self.tokenizer.mask_token is None:
|
1081 |
+
raise ValueError(
|
1082 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
|
1083 |
+
" --mlm flag if you want to use this tokenizer."
|
1084 |
+
)
|
1085 |
+
labels = np.copy(inputs)
|
1086 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
1087 |
+
|
1088 |
+
masked_indices = mask_labels.astype(bool)
|
1089 |
+
|
1090 |
+
special_tokens_mask = [
|
1091 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
1092 |
+
]
|
1093 |
+
masked_indices[np.array(special_tokens_mask, dtype=bool)] = 0
|
1094 |
+
if self.tokenizer._pad_token is not None:
|
1095 |
+
padding_mask = labels == self.tokenizer.pad_token_id
|
1096 |
+
masked_indices[padding_mask] = 0
|
1097 |
+
|
1098 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
1099 |
+
|
1100 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
1101 |
+
indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
|
1102 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
1103 |
+
|
1104 |
+
# 10% of the time, we replace masked input tokens with random word
|
1105 |
+
# indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
1106 |
+
indices_random = (
|
1107 |
+
np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
|
1108 |
+
)
|
1109 |
+
random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
|
1110 |
+
inputs[indices_random] = random_words[indices_random]
|
1111 |
+
|
1112 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
1113 |
+
return inputs, labels
|
1114 |
+
|
1115 |
+
|
1116 |
+
@dataclass
|
1117 |
+
class DataCollatorForSOP(DataCollatorForLanguageModeling):
|
1118 |
+
"""
|
1119 |
+
Data collator used for sentence order prediction task.
|
1120 |
+
|
1121 |
+
- collates batches of tensors, honoring their tokenizer's pad_token
|
1122 |
+
- preprocesses batches for both masked language modeling and sentence order prediction
|
1123 |
+
"""
|
1124 |
+
|
1125 |
+
def __init__(self, *args, **kwargs):
|
1126 |
+
warnings.warn(
|
1127 |
+
"DataCollatorForSOP is deprecated and will be removed in a future version, you can now use "
|
1128 |
+
"DataCollatorForLanguageModeling instead.",
|
1129 |
+
FutureWarning,
|
1130 |
+
)
|
1131 |
+
|
1132 |
+
def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
|
1133 |
+
import torch
|
1134 |
+
from torch.nn.utils.rnn import pad_sequence
|
1135 |
+
|
1136 |
+
input_ids = [example["input_ids"] for example in examples]
|
1137 |
+
input_ids = _torch_collate_batch(input_ids, self.tokenizer)
|
1138 |
+
input_ids, labels, attention_mask = self.mask_tokens(input_ids)
|
1139 |
+
|
1140 |
+
token_type_ids = [example["token_type_ids"] for example in examples]
|
1141 |
+
# size of segment_ids varied because randomness, padding zero to the end as the original implementation
|
1142 |
+
token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
|
1143 |
+
|
1144 |
+
sop_label_list = [example["sentence_order_label"] for example in examples]
|
1145 |
+
sentence_order_label = torch.stack(sop_label_list)
|
1146 |
+
|
1147 |
+
return {
|
1148 |
+
"input_ids": input_ids,
|
1149 |
+
"labels": labels,
|
1150 |
+
"attention_mask": attention_mask,
|
1151 |
+
"token_type_ids": token_type_ids,
|
1152 |
+
"sentence_order_label": sentence_order_label,
|
1153 |
+
}
|
1154 |
+
|
1155 |
+
def mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any]:
|
1156 |
+
"""
|
1157 |
+
Prepare masked tokens inputs/labels/attention_mask for masked language modeling: 80% MASK, 10% random, 10%
|
1158 |
+
original. N-gram not applied yet.
|
1159 |
+
"""
|
1160 |
+
import torch
|
1161 |
+
|
1162 |
+
if self.tokenizer.mask_token is None:
|
1163 |
+
raise ValueError(
|
1164 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
|
1165 |
+
" --mlm flag if you want to use this tokenizer."
|
1166 |
+
)
|
1167 |
+
|
1168 |
+
labels = inputs.clone()
|
1169 |
+
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
|
1170 |
+
probability_matrix = torch.full(labels.shape, self.mlm_probability)
|
1171 |
+
special_tokens_mask = [
|
1172 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
|
1173 |
+
]
|
1174 |
+
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
|
1175 |
+
if self.tokenizer._pad_token is not None:
|
1176 |
+
padding_mask = labels.eq(self.tokenizer.pad_token_id)
|
1177 |
+
probability_matrix.masked_fill_(padding_mask, value=0.0)
|
1178 |
+
masked_indices = torch.bernoulli(probability_matrix).bool()
|
1179 |
+
# probability be `1` (masked), however in albert model attention mask `0` means masked, revert the value
|
1180 |
+
attention_mask = (~masked_indices).float()
|
1181 |
+
if self.tokenizer._pad_token is not None:
|
1182 |
+
attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
|
1183 |
+
attention_mask.masked_fill_(attention_padding_mask, value=1.0)
|
1184 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens, -100 is default for CE compute
|
1185 |
+
|
1186 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
1187 |
+
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
1188 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
1189 |
+
|
1190 |
+
# 10% of the time, we replace masked input tokens with random word
|
1191 |
+
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
1192 |
+
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
|
1193 |
+
inputs[indices_random] = random_words[indices_random]
|
1194 |
+
|
1195 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
1196 |
+
return inputs, labels, attention_mask
|
1197 |
+
|
1198 |
+
|
1199 |
+
@dataclass
|
1200 |
+
class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
|
1201 |
+
"""
|
1202 |
+
Data collator used for permutation language modeling.
|
1203 |
+
|
1204 |
+
- collates batches of tensors, honoring their tokenizer's pad_token
|
1205 |
+
- preprocesses batches for permutation language modeling with procedures specific to XLNet
|
1206 |
+
"""
|
1207 |
+
|
1208 |
+
tokenizer: PreTrainedTokenizerBase
|
1209 |
+
plm_probability: float = 1 / 6
|
1210 |
+
max_span_length: int = 5 # maximum length of a span of masked tokens
|
1211 |
+
return_tensors: str = "pt"
|
1212 |
+
|
1213 |
+
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
1214 |
+
if isinstance(examples[0], Mapping):
|
1215 |
+
examples = [e["input_ids"] for e in examples]
|
1216 |
+
batch = _torch_collate_batch(examples, self.tokenizer)
|
1217 |
+
inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
|
1218 |
+
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
1219 |
+
|
1220 |
+
def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
1221 |
+
if isinstance(examples[0], Mapping):
|
1222 |
+
examples = [e["input_ids"] for e in examples]
|
1223 |
+
batch = _tf_collate_batch(examples, self.tokenizer)
|
1224 |
+
inputs, perm_mask, target_mapping, labels = self.tf_mask_tokens(batch)
|
1225 |
+
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
1226 |
+
|
1227 |
+
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
|
1228 |
+
if isinstance(examples[0], Mapping):
|
1229 |
+
examples = [e["input_ids"] for e in examples]
|
1230 |
+
batch = _numpy_collate_batch(examples, self.tokenizer)
|
1231 |
+
inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
|
1232 |
+
return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
|
1233 |
+
|
1234 |
+
def torch_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
|
1235 |
+
"""
|
1236 |
+
The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
|
1237 |
+
|
1238 |
+
0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
1239 |
+
1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
1240 |
+
2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
|
1241 |
+
masked
|
1242 |
+
3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
|
1243 |
+
span_length]` and mask tokens `start_index:start_index + span_length`
|
1244 |
+
4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
|
1245 |
+
sequence to be processed), repeat from Step 1.
|
1246 |
+
"""
|
1247 |
+
import torch
|
1248 |
+
|
1249 |
+
if self.tokenizer.mask_token is None:
|
1250 |
+
raise ValueError(
|
1251 |
+
"This tokenizer does not have a mask token which is necessary for permutation language modeling."
|
1252 |
+
" Please add a mask token if you want to use this tokenizer."
|
1253 |
+
)
|
1254 |
+
|
1255 |
+
if inputs.size(1) % 2 != 0:
|
1256 |
+
raise ValueError(
|
1257 |
+
"This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
|
1258 |
+
" relevant comments in source code for details."
|
1259 |
+
)
|
1260 |
+
|
1261 |
+
labels = inputs.clone()
|
1262 |
+
# Creating the mask and target_mapping tensors
|
1263 |
+
masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
|
1264 |
+
target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
|
1265 |
+
|
1266 |
+
for i in range(labels.size(0)):
|
1267 |
+
# Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
1268 |
+
cur_len = 0
|
1269 |
+
max_len = labels.size(1)
|
1270 |
+
|
1271 |
+
while cur_len < max_len:
|
1272 |
+
# Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
1273 |
+
span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
|
1274 |
+
# Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
|
1275 |
+
context_length = int(span_length / self.plm_probability)
|
1276 |
+
# Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
|
1277 |
+
start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
|
1278 |
+
masked_indices[i, start_index : start_index + span_length] = 1
|
1279 |
+
# Set `cur_len = cur_len + context_length`
|
1280 |
+
cur_len += context_length
|
1281 |
+
|
1282 |
+
# Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
|
1283 |
+
# the i-th predict corresponds to the i-th token.
|
1284 |
+
target_mapping[i] = torch.eye(labels.size(1))
|
1285 |
+
|
1286 |
+
special_tokens_mask = torch.tensor(
|
1287 |
+
[self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
|
1288 |
+
dtype=torch.bool,
|
1289 |
+
)
|
1290 |
+
masked_indices.masked_fill_(special_tokens_mask, value=0.0)
|
1291 |
+
if self.tokenizer._pad_token is not None:
|
1292 |
+
padding_mask = labels.eq(self.tokenizer.pad_token_id)
|
1293 |
+
masked_indices.masked_fill_(padding_mask, value=0.0)
|
1294 |
+
|
1295 |
+
# Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
|
1296 |
+
non_func_mask = ~(padding_mask | special_tokens_mask)
|
1297 |
+
|
1298 |
+
inputs[masked_indices] = self.tokenizer.mask_token_id
|
1299 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
1300 |
+
|
1301 |
+
perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
|
1302 |
+
|
1303 |
+
for i in range(labels.size(0)):
|
1304 |
+
# Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
|
1305 |
+
# determine which tokens a given token can attend to (encoded in `perm_mask`).
|
1306 |
+
# Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
|
1307 |
+
# (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
|
1308 |
+
# we assume that reused length is half of sequence length and permutation length is equal to reused length.
|
1309 |
+
# This requires that the sequence length be even.
|
1310 |
+
|
1311 |
+
# Create a linear factorisation order
|
1312 |
+
perm_index = torch.arange(labels.size(1))
|
1313 |
+
# Split this into two halves, assuming that half the sequence is reused each time
|
1314 |
+
perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
|
1315 |
+
# Permute the two halves such that they do not cross over
|
1316 |
+
perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
|
1317 |
+
# Flatten this out into the desired permuted factorisation order
|
1318 |
+
perm_index = torch.flatten(perm_index.transpose(0, 1))
|
1319 |
+
# Set the permutation indices of non-masked (non-functional) tokens to the
|
1320 |
+
# smallest index (-1) so that:
|
1321 |
+
# (1) They can be seen by all other positions
|
1322 |
+
# (2) They cannot see masked positions, so there won't be information leak
|
1323 |
+
perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
|
1324 |
+
# The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
|
1325 |
+
# 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
|
1326 |
+
# 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
|
1327 |
+
perm_mask[i] = (
|
1328 |
+
perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
|
1329 |
+
) & masked_indices[i]
|
1330 |
+
|
1331 |
+
return inputs.long(), perm_mask, target_mapping, labels.long()
|
1332 |
+
|
1333 |
+
def tf_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
|
1334 |
+
"""
|
1335 |
+
The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
|
1336 |
+
|
1337 |
+
0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
1338 |
+
1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
1339 |
+
2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
|
1340 |
+
masked
|
1341 |
+
3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
|
1342 |
+
span_length]` and mask tokens `start_index:start_index + span_length`
|
1343 |
+
4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
|
1344 |
+
sequence to be processed), repeat from Step 1.
|
1345 |
+
"""
|
1346 |
+
import tensorflow as tf
|
1347 |
+
|
1348 |
+
if self.tokenizer.mask_token is None:
|
1349 |
+
raise ValueError(
|
1350 |
+
"This tokenizer does not have a mask token which is necessary for permutation language modeling."
|
1351 |
+
" Please add a mask token if you want to use this tokenizer."
|
1352 |
+
)
|
1353 |
+
|
1354 |
+
if tf.shape(inputs)[1] % 2 != 0:
|
1355 |
+
raise ValueError(
|
1356 |
+
"This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
|
1357 |
+
" relevant comments in source code for details."
|
1358 |
+
)
|
1359 |
+
|
1360 |
+
labels = tf.identity(inputs)
|
1361 |
+
# Creating the mask and target_mapping tensors
|
1362 |
+
masked_indices = np.full(labels.shape.as_list(), 0, dtype=bool)
|
1363 |
+
labels_shape = tf.shape(labels)
|
1364 |
+
target_mapping = np.zeros((labels_shape[0], labels_shape[1], labels_shape[1]), dtype=np.float32)
|
1365 |
+
|
1366 |
+
for i in range(len(labels)):
|
1367 |
+
# Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
1368 |
+
cur_len = 0
|
1369 |
+
max_len = tf.shape(labels)[1]
|
1370 |
+
|
1371 |
+
while cur_len < max_len:
|
1372 |
+
# Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
1373 |
+
span_length = randint(1, self.max_span_length + 1)
|
1374 |
+
# Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
|
1375 |
+
context_length = int(span_length / self.plm_probability)
|
1376 |
+
# Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
|
1377 |
+
start_index = cur_len + randint(0, context_length - span_length + 1)
|
1378 |
+
masked_indices[i, start_index : start_index + span_length] = 1
|
1379 |
+
# Set `cur_len = cur_len + context_length`
|
1380 |
+
cur_len += context_length
|
1381 |
+
|
1382 |
+
# Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
|
1383 |
+
# the i-th predict corresponds to the i-th token.
|
1384 |
+
target_mapping[i] = np.eye(labels_shape[1])
|
1385 |
+
masked_indices = tf.cast(tf.convert_to_tensor(masked_indices), dtype=tf.bool)
|
1386 |
+
target_mapping = tf.convert_to_tensor(target_mapping)
|
1387 |
+
special_tokens_mask = tf.convert_to_tensor(
|
1388 |
+
[
|
1389 |
+
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
|
1390 |
+
for val in labels.numpy().tolist()
|
1391 |
+
],
|
1392 |
+
)
|
1393 |
+
special_tokens_mask = tf.cast(special_tokens_mask, dtype=tf.bool)
|
1394 |
+
masked_indices = masked_indices & ~special_tokens_mask
|
1395 |
+
if self.tokenizer._pad_token is not None:
|
1396 |
+
padding_mask = labels == self.tokenizer.pad_token_id
|
1397 |
+
masked_indices = masked_indices & ~padding_mask
|
1398 |
+
|
1399 |
+
# Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
|
1400 |
+
non_func_mask = ~(padding_mask | special_tokens_mask)
|
1401 |
+
|
1402 |
+
inputs = tf.where(masked_indices, self.tokenizer.mask_token_id, inputs)
|
1403 |
+
labels = tf.where(masked_indices, labels, -100) # We only compute loss on masked tokens
|
1404 |
+
|
1405 |
+
perm_mask = []
|
1406 |
+
|
1407 |
+
for i in range(len(labels)):
|
1408 |
+
# Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
|
1409 |
+
# determine which tokens a given token can attend to (encoded in `perm_mask`).
|
1410 |
+
# Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
|
1411 |
+
# (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
|
1412 |
+
# we assume that reused length is half of sequence length and permutation length is equal to reused length.
|
1413 |
+
# This requires that the sequence length be even.
|
1414 |
+
|
1415 |
+
# Create a linear factorisation order
|
1416 |
+
# tf.range is the equivalent of torch.arange
|
1417 |
+
perm_index = tf.range(labels_shape[1])
|
1418 |
+
# Split this into two halves, assuming that half the sequence is reused each time
|
1419 |
+
perm_index = tf.transpose(tf.reshape(perm_index, (-1, labels_shape[1] // 2)))
|
1420 |
+
# Permute the two halves such that they do not cross over
|
1421 |
+
perm_index = tf.random.shuffle(perm_index) # Shuffles along the first dimension
|
1422 |
+
# Flatten this out into the desired permuted factorisation order
|
1423 |
+
perm_index = tf.reshape(tf.transpose(perm_index), (-1,))
|
1424 |
+
# Set the permutation indices of non-masked (non-functional) tokens to the
|
1425 |
+
# smallest index (-1) so that:
|
1426 |
+
# (1) They can be seen by all other positions
|
1427 |
+
# (2) They cannot see masked positions, so there won't be information leak
|
1428 |
+
perm_index = tf.where(~masked_indices[i] & non_func_mask[i], -1, perm_index)
|
1429 |
+
# The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
|
1430 |
+
# 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
|
1431 |
+
# 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
|
1432 |
+
perm_mask.append(
|
1433 |
+
(tf.reshape(perm_index, (labels_shape[1], 1)) <= tf.reshape(perm_index, (1, labels_shape[1])))
|
1434 |
+
& masked_indices[i]
|
1435 |
+
)
|
1436 |
+
perm_mask = tf.stack(perm_mask, axis=0)
|
1437 |
+
|
1438 |
+
return tf.cast(inputs, tf.int64), tf.cast(perm_mask, tf.float32), target_mapping, tf.cast(labels, tf.int64)
|
1439 |
+
|
1440 |
+
def numpy_mask_tokens(self, inputs: Any) -> Tuple[Any, Any, Any, Any]:
|
1441 |
+
"""
|
1442 |
+
The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
|
1443 |
+
|
1444 |
+
0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
1445 |
+
1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
1446 |
+
2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
|
1447 |
+
masked
|
1448 |
+
3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
|
1449 |
+
span_length]` and mask tokens `start_index:start_index + span_length`
|
1450 |
+
4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
|
1451 |
+
sequence to be processed), repeat from Step 1.
|
1452 |
+
"""
|
1453 |
+
if self.tokenizer.mask_token is None:
|
1454 |
+
raise ValueError(
|
1455 |
+
"This tokenizer does not have a mask token which is necessary for permutation language modeling."
|
1456 |
+
" Please add a mask token if you want to use this tokenizer."
|
1457 |
+
)
|
1458 |
+
|
1459 |
+
if inputs.shape[1] % 2 != 0:
|
1460 |
+
raise ValueError(
|
1461 |
+
"This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
|
1462 |
+
" relevant comments in source code for details."
|
1463 |
+
)
|
1464 |
+
|
1465 |
+
labels = np.copy(inputs)
|
1466 |
+
# Creating the mask and target_mapping tensors
|
1467 |
+
masked_indices = np.full(labels.shape, 0, dtype=bool)
|
1468 |
+
target_mapping = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
|
1469 |
+
|
1470 |
+
for i in range(labels.shape[0]):
|
1471 |
+
# Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
|
1472 |
+
cur_len = 0
|
1473 |
+
max_len = labels.shape[1]
|
1474 |
+
|
1475 |
+
while cur_len < max_len:
|
1476 |
+
# Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
|
1477 |
+
span_length = randint(1, self.max_span_length + 1)
|
1478 |
+
# Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
|
1479 |
+
context_length = int(span_length / self.plm_probability)
|
1480 |
+
# Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
|
1481 |
+
start_index = cur_len + randint(0, context_length - span_length + 1)
|
1482 |
+
masked_indices[i, start_index : start_index + span_length] = 1
|
1483 |
+
# Set `cur_len = cur_len + context_length`
|
1484 |
+
cur_len += context_length
|
1485 |
+
|
1486 |
+
# Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
|
1487 |
+
# the i-th predict corresponds to the i-th token.
|
1488 |
+
target_mapping[i] = np.eye(labels.shape[1])
|
1489 |
+
|
1490 |
+
special_tokens_mask = np.array(
|
1491 |
+
[self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
|
1492 |
+
dtype=bool,
|
1493 |
+
)
|
1494 |
+
masked_indices[special_tokens_mask] = 0
|
1495 |
+
if self.tokenizer._pad_token is not None:
|
1496 |
+
padding_mask = labels == self.tokenizer.pad_token_id
|
1497 |
+
masked_indices[padding_mask] = 0.0
|
1498 |
+
|
1499 |
+
# Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
|
1500 |
+
non_func_mask = ~(padding_mask | special_tokens_mask)
|
1501 |
+
|
1502 |
+
inputs[masked_indices] = self.tokenizer.mask_token_id
|
1503 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
1504 |
+
|
1505 |
+
perm_mask = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
|
1506 |
+
|
1507 |
+
for i in range(labels.shape[0]):
|
1508 |
+
# Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
|
1509 |
+
# determine which tokens a given token can attend to (encoded in `perm_mask`).
|
1510 |
+
# Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
|
1511 |
+
# (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
|
1512 |
+
# we assume that reused length is half of sequence length and permutation length is equal to reused length.
|
1513 |
+
# This requires that the sequence length be even.
|
1514 |
+
|
1515 |
+
# Create a linear factorisation order
|
1516 |
+
perm_index = np.arange(labels.shape[1])
|
1517 |
+
# Split this into two halves, assuming that half the sequence is reused each time
|
1518 |
+
perm_index = perm_index.reshape((-1, labels.shape[1] // 2)).T
|
1519 |
+
# Permute the two halves such that they do not cross over
|
1520 |
+
np.random.shuffle(perm_index)
|
1521 |
+
# Flatten this out into the desired permuted factorisation order
|
1522 |
+
perm_index = perm_index.T.flatten()
|
1523 |
+
# Set the permutation indices of non-masked (non-functional) tokens to the
|
1524 |
+
# smallest index (-1) so that:
|
1525 |
+
# (1) They can be seen by all other positions
|
1526 |
+
# (2) They cannot see masked positions, so there won't be information leak
|
1527 |
+
perm_index[~masked_indices[i] & non_func_mask[i]] = -1
|
1528 |
+
# The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
|
1529 |
+
# 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
|
1530 |
+
# 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
|
1531 |
+
perm_mask[i] = (
|
1532 |
+
perm_index.reshape((labels.shape[1], 1)) <= perm_index.reshape((1, labels.shape[1]))
|
1533 |
+
) & masked_indices[i]
|
1534 |
+
|
1535 |
+
return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)
|
transformers_4_35_0/data/datasets/__init__.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .glue import GlueDataset, GlueDataTrainingArguments
|
16 |
+
from .language_modeling import (
|
17 |
+
LineByLineTextDataset,
|
18 |
+
LineByLineWithRefDataset,
|
19 |
+
LineByLineWithSOPTextDataset,
|
20 |
+
TextDataset,
|
21 |
+
TextDatasetForNextSentencePrediction,
|
22 |
+
)
|
23 |
+
from .squad import SquadDataset, SquadDataTrainingArguments
|
transformers_4_35_0/data/datasets/glue.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import time
|
17 |
+
import warnings
|
18 |
+
from dataclasses import dataclass, field
|
19 |
+
from enum import Enum
|
20 |
+
from typing import List, Optional, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from filelock import FileLock
|
24 |
+
from torch.utils.data import Dataset
|
25 |
+
|
26 |
+
from ...tokenization_utils_base import PreTrainedTokenizerBase
|
27 |
+
from ...utils import logging
|
28 |
+
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
|
29 |
+
from ..processors.utils import InputFeatures
|
30 |
+
|
31 |
+
|
32 |
+
logger = logging.get_logger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass
|
36 |
+
class GlueDataTrainingArguments:
|
37 |
+
"""
|
38 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
39 |
+
|
40 |
+
Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify them on the command
|
41 |
+
line.
|
42 |
+
"""
|
43 |
+
|
44 |
+
task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
|
45 |
+
data_dir: str = field(
|
46 |
+
metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
|
47 |
+
)
|
48 |
+
max_seq_length: int = field(
|
49 |
+
default=128,
|
50 |
+
metadata={
|
51 |
+
"help": (
|
52 |
+
"The maximum total input sequence length after tokenization. Sequences longer "
|
53 |
+
"than this will be truncated, sequences shorter will be padded."
|
54 |
+
)
|
55 |
+
},
|
56 |
+
)
|
57 |
+
overwrite_cache: bool = field(
|
58 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
59 |
+
)
|
60 |
+
|
61 |
+
def __post_init__(self):
|
62 |
+
self.task_name = self.task_name.lower()
|
63 |
+
|
64 |
+
|
65 |
+
class Split(Enum):
|
66 |
+
train = "train"
|
67 |
+
dev = "dev"
|
68 |
+
test = "test"
|
69 |
+
|
70 |
+
|
71 |
+
class GlueDataset(Dataset):
|
72 |
+
"""
|
73 |
+
This will be superseded by a framework-agnostic approach soon.
|
74 |
+
"""
|
75 |
+
|
76 |
+
args: GlueDataTrainingArguments
|
77 |
+
output_mode: str
|
78 |
+
features: List[InputFeatures]
|
79 |
+
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
args: GlueDataTrainingArguments,
|
83 |
+
tokenizer: PreTrainedTokenizerBase,
|
84 |
+
limit_length: Optional[int] = None,
|
85 |
+
mode: Union[str, Split] = Split.train,
|
86 |
+
cache_dir: Optional[str] = None,
|
87 |
+
):
|
88 |
+
warnings.warn(
|
89 |
+
"This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
|
90 |
+
"library. You can have a look at this example script for pointers: "
|
91 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
|
92 |
+
FutureWarning,
|
93 |
+
)
|
94 |
+
self.args = args
|
95 |
+
self.processor = glue_processors[args.task_name]()
|
96 |
+
self.output_mode = glue_output_modes[args.task_name]
|
97 |
+
if isinstance(mode, str):
|
98 |
+
try:
|
99 |
+
mode = Split[mode]
|
100 |
+
except KeyError:
|
101 |
+
raise KeyError("mode is not a valid split name")
|
102 |
+
# Load data features from cache or dataset file
|
103 |
+
cached_features_file = os.path.join(
|
104 |
+
cache_dir if cache_dir is not None else args.data_dir,
|
105 |
+
f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{args.task_name}",
|
106 |
+
)
|
107 |
+
label_list = self.processor.get_labels()
|
108 |
+
if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__.__name__ in (
|
109 |
+
"RobertaTokenizer",
|
110 |
+
"RobertaTokenizerFast",
|
111 |
+
"XLMRobertaTokenizer",
|
112 |
+
"BartTokenizer",
|
113 |
+
"BartTokenizerFast",
|
114 |
+
):
|
115 |
+
# HACK(label indices are swapped in RoBERTa pretrained model)
|
116 |
+
label_list[1], label_list[2] = label_list[2], label_list[1]
|
117 |
+
self.label_list = label_list
|
118 |
+
|
119 |
+
# Make sure only the first process in distributed training processes the dataset,
|
120 |
+
# and the others will use the cache.
|
121 |
+
lock_path = cached_features_file + ".lock"
|
122 |
+
with FileLock(lock_path):
|
123 |
+
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
124 |
+
start = time.time()
|
125 |
+
self.features = torch.load(cached_features_file)
|
126 |
+
logger.info(
|
127 |
+
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
128 |
+
)
|
129 |
+
else:
|
130 |
+
logger.info(f"Creating features from dataset file at {args.data_dir}")
|
131 |
+
|
132 |
+
if mode == Split.dev:
|
133 |
+
examples = self.processor.get_dev_examples(args.data_dir)
|
134 |
+
elif mode == Split.test:
|
135 |
+
examples = self.processor.get_test_examples(args.data_dir)
|
136 |
+
else:
|
137 |
+
examples = self.processor.get_train_examples(args.data_dir)
|
138 |
+
if limit_length is not None:
|
139 |
+
examples = examples[:limit_length]
|
140 |
+
self.features = glue_convert_examples_to_features(
|
141 |
+
examples,
|
142 |
+
tokenizer,
|
143 |
+
max_length=args.max_seq_length,
|
144 |
+
label_list=label_list,
|
145 |
+
output_mode=self.output_mode,
|
146 |
+
)
|
147 |
+
start = time.time()
|
148 |
+
torch.save(self.features, cached_features_file)
|
149 |
+
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
|
150 |
+
logger.info(
|
151 |
+
f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
|
152 |
+
)
|
153 |
+
|
154 |
+
def __len__(self):
|
155 |
+
return len(self.features)
|
156 |
+
|
157 |
+
def __getitem__(self, i) -> InputFeatures:
|
158 |
+
return self.features[i]
|
159 |
+
|
160 |
+
def get_labels(self):
|
161 |
+
return self.label_list
|
transformers_4_35_0/data/datasets/language_modeling.py
ADDED
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
import pickle
|
18 |
+
import random
|
19 |
+
import time
|
20 |
+
import warnings
|
21 |
+
from typing import Dict, List, Optional
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from filelock import FileLock
|
25 |
+
from torch.utils.data import Dataset
|
26 |
+
|
27 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
28 |
+
from ...utils import logging
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
DEPRECATION_WARNING = (
|
35 |
+
"This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
|
36 |
+
"library. You can have a look at this example script for pointers: {0}"
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
class TextDataset(Dataset):
|
41 |
+
"""
|
42 |
+
This will be superseded by a framework-agnostic approach soon.
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
tokenizer: PreTrainedTokenizer,
|
48 |
+
file_path: str,
|
49 |
+
block_size: int,
|
50 |
+
overwrite_cache=False,
|
51 |
+
cache_dir: Optional[str] = None,
|
52 |
+
):
|
53 |
+
warnings.warn(
|
54 |
+
DEPRECATION_WARNING.format(
|
55 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
|
56 |
+
),
|
57 |
+
FutureWarning,
|
58 |
+
)
|
59 |
+
if os.path.isfile(file_path) is False:
|
60 |
+
raise ValueError(f"Input file path {file_path} not found")
|
61 |
+
|
62 |
+
block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
|
63 |
+
|
64 |
+
directory, filename = os.path.split(file_path)
|
65 |
+
cached_features_file = os.path.join(
|
66 |
+
cache_dir if cache_dir is not None else directory,
|
67 |
+
f"cached_lm_{tokenizer.__class__.__name__}_{block_size}_{filename}",
|
68 |
+
)
|
69 |
+
|
70 |
+
# Make sure only the first process in distributed training processes the dataset,
|
71 |
+
# and the others will use the cache.
|
72 |
+
lock_path = cached_features_file + ".lock"
|
73 |
+
with FileLock(lock_path):
|
74 |
+
if os.path.exists(cached_features_file) and not overwrite_cache:
|
75 |
+
start = time.time()
|
76 |
+
with open(cached_features_file, "rb") as handle:
|
77 |
+
self.examples = pickle.load(handle)
|
78 |
+
logger.info(
|
79 |
+
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
80 |
+
)
|
81 |
+
|
82 |
+
else:
|
83 |
+
logger.info(f"Creating features from dataset file at {directory}")
|
84 |
+
|
85 |
+
self.examples = []
|
86 |
+
with open(file_path, encoding="utf-8") as f:
|
87 |
+
text = f.read()
|
88 |
+
|
89 |
+
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
|
90 |
+
|
91 |
+
for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size
|
92 |
+
self.examples.append(
|
93 |
+
tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])
|
94 |
+
)
|
95 |
+
# Note that we are losing the last truncated example here for the sake of simplicity (no padding)
|
96 |
+
# If your dataset is small, first you should look for a bigger one :-) and second you
|
97 |
+
# can change this behavior by adding (model specific) padding.
|
98 |
+
|
99 |
+
start = time.time()
|
100 |
+
with open(cached_features_file, "wb") as handle:
|
101 |
+
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
102 |
+
logger.info(
|
103 |
+
f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
|
104 |
+
)
|
105 |
+
|
106 |
+
def __len__(self):
|
107 |
+
return len(self.examples)
|
108 |
+
|
109 |
+
def __getitem__(self, i) -> torch.Tensor:
|
110 |
+
return torch.tensor(self.examples[i], dtype=torch.long)
|
111 |
+
|
112 |
+
|
113 |
+
class LineByLineTextDataset(Dataset):
|
114 |
+
"""
|
115 |
+
This will be superseded by a framework-agnostic approach soon.
|
116 |
+
"""
|
117 |
+
|
118 |
+
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
|
119 |
+
warnings.warn(
|
120 |
+
DEPRECATION_WARNING.format(
|
121 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
|
122 |
+
),
|
123 |
+
FutureWarning,
|
124 |
+
)
|
125 |
+
if os.path.isfile(file_path) is False:
|
126 |
+
raise ValueError(f"Input file path {file_path} not found")
|
127 |
+
# Here, we do not cache the features, operating under the assumption
|
128 |
+
# that we will soon use fast multithreaded tokenizers from the
|
129 |
+
# `tokenizers` repo everywhere =)
|
130 |
+
logger.info(f"Creating features from dataset file at {file_path}")
|
131 |
+
|
132 |
+
with open(file_path, encoding="utf-8") as f:
|
133 |
+
lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
|
134 |
+
|
135 |
+
batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)
|
136 |
+
self.examples = batch_encoding["input_ids"]
|
137 |
+
self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
|
138 |
+
|
139 |
+
def __len__(self):
|
140 |
+
return len(self.examples)
|
141 |
+
|
142 |
+
def __getitem__(self, i) -> Dict[str, torch.tensor]:
|
143 |
+
return self.examples[i]
|
144 |
+
|
145 |
+
|
146 |
+
class LineByLineWithRefDataset(Dataset):
|
147 |
+
"""
|
148 |
+
This will be superseded by a framework-agnostic approach soon.
|
149 |
+
"""
|
150 |
+
|
151 |
+
def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
|
152 |
+
warnings.warn(
|
153 |
+
DEPRECATION_WARNING.format(
|
154 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm_wwm.py"
|
155 |
+
),
|
156 |
+
FutureWarning,
|
157 |
+
)
|
158 |
+
if os.path.isfile(file_path) is False:
|
159 |
+
raise ValueError(f"Input file path {file_path} not found")
|
160 |
+
if os.path.isfile(ref_path) is False:
|
161 |
+
raise ValueError(f"Ref file path {file_path} not found")
|
162 |
+
# Here, we do not cache the features, operating under the assumption
|
163 |
+
# that we will soon use fast multithreaded tokenizers from the
|
164 |
+
# `tokenizers` repo everywhere =)
|
165 |
+
logger.info(f"Creating features from dataset file at {file_path}")
|
166 |
+
logger.info(f"Use ref segment results at {ref_path}")
|
167 |
+
with open(file_path, encoding="utf-8") as f:
|
168 |
+
data = f.readlines() # use this method to avoid delimiter '\u2029' to split a line
|
169 |
+
data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]
|
170 |
+
# Get ref inf from file
|
171 |
+
with open(ref_path, encoding="utf-8") as f:
|
172 |
+
ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
|
173 |
+
if len(data) != len(ref):
|
174 |
+
raise ValueError(
|
175 |
+
f"Length of Input file should be equal to Ref file. But the length of {file_path} is {len(data)} "
|
176 |
+
f"while length of {ref_path} is {len(ref)}"
|
177 |
+
)
|
178 |
+
|
179 |
+
batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
|
180 |
+
self.examples = batch_encoding["input_ids"]
|
181 |
+
self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
|
182 |
+
|
183 |
+
n = len(self.examples)
|
184 |
+
for i in range(n):
|
185 |
+
self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)
|
186 |
+
|
187 |
+
def __len__(self):
|
188 |
+
return len(self.examples)
|
189 |
+
|
190 |
+
def __getitem__(self, i) -> Dict[str, torch.tensor]:
|
191 |
+
return self.examples[i]
|
192 |
+
|
193 |
+
|
194 |
+
class LineByLineWithSOPTextDataset(Dataset):
|
195 |
+
"""
|
196 |
+
Dataset for sentence order prediction task, prepare sentence pairs for SOP task
|
197 |
+
"""
|
198 |
+
|
199 |
+
def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
|
200 |
+
warnings.warn(
|
201 |
+
DEPRECATION_WARNING.format(
|
202 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
|
203 |
+
),
|
204 |
+
FutureWarning,
|
205 |
+
)
|
206 |
+
if os.path.isdir(file_dir) is False:
|
207 |
+
raise ValueError(f"{file_dir} is not a directory")
|
208 |
+
logger.info(f"Creating features from dataset file folder at {file_dir}")
|
209 |
+
self.examples = []
|
210 |
+
# TODO: randomness could apply a random seed, ex. rng = random.Random(random_seed)
|
211 |
+
# file path looks like ./dataset/wiki_1, ./dataset/wiki_2
|
212 |
+
for file_name in os.listdir(file_dir):
|
213 |
+
file_path = os.path.join(file_dir, file_name)
|
214 |
+
if os.path.isfile(file_path) is False:
|
215 |
+
raise ValueError(f"{file_path} is not a file")
|
216 |
+
article_open = False
|
217 |
+
with open(file_path, encoding="utf-8") as f:
|
218 |
+
original_lines = f.readlines()
|
219 |
+
article_lines = []
|
220 |
+
for line in original_lines:
|
221 |
+
if "<doc id=" in line:
|
222 |
+
article_open = True
|
223 |
+
elif "</doc>" in line:
|
224 |
+
article_open = False
|
225 |
+
document = [
|
226 |
+
tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
|
227 |
+
for line in article_lines[1:]
|
228 |
+
if (len(line) > 0 and not line.isspace())
|
229 |
+
]
|
230 |
+
|
231 |
+
examples = self.create_examples_from_document(document, block_size, tokenizer)
|
232 |
+
self.examples.extend(examples)
|
233 |
+
article_lines = []
|
234 |
+
else:
|
235 |
+
if article_open:
|
236 |
+
article_lines.append(line)
|
237 |
+
|
238 |
+
logger.info("Dataset parse finished.")
|
239 |
+
|
240 |
+
def create_examples_from_document(self, document, block_size, tokenizer, short_seq_prob=0.1):
|
241 |
+
"""Creates examples for a single document."""
|
242 |
+
|
243 |
+
# Account for special tokens
|
244 |
+
max_num_tokens = block_size - tokenizer.num_special_tokens_to_add(pair=True)
|
245 |
+
|
246 |
+
# We *usually* want to fill up the entire sequence since we are padding
|
247 |
+
# to `block_size` anyways, so short sequences are generally wasted
|
248 |
+
# computation. However, we *sometimes*
|
249 |
+
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
250 |
+
# sequences to minimize the mismatch between pretraining and fine-tuning.
|
251 |
+
# The `target_seq_length` is just a rough target however, whereas
|
252 |
+
# `block_size` is a hard limit.
|
253 |
+
target_seq_length = max_num_tokens
|
254 |
+
if random.random() < short_seq_prob:
|
255 |
+
target_seq_length = random.randint(2, max_num_tokens)
|
256 |
+
|
257 |
+
# We DON'T just concatenate all of the tokens from a document into a long
|
258 |
+
# sequence and choose an arbitrary split point because this would make the
|
259 |
+
# next sentence prediction task too easy. Instead, we split the input into
|
260 |
+
# segments "A" and "B" based on the actual "sentences" provided by the user
|
261 |
+
# input.
|
262 |
+
examples = []
|
263 |
+
current_chunk = [] # a buffer stored current working segments
|
264 |
+
current_length = 0
|
265 |
+
i = 0
|
266 |
+
while i < len(document):
|
267 |
+
segment = document[i] # get a segment
|
268 |
+
if not segment:
|
269 |
+
i += 1
|
270 |
+
continue
|
271 |
+
current_chunk.append(segment) # add a segment to current chunk
|
272 |
+
current_length += len(segment) # overall token length
|
273 |
+
# if current length goes to the target length or reaches the end of file, start building token a and b
|
274 |
+
if i == len(document) - 1 or current_length >= target_seq_length:
|
275 |
+
if current_chunk:
|
276 |
+
# `a_end` is how many segments from `current_chunk` go into the `A` (first) sentence.
|
277 |
+
a_end = 1
|
278 |
+
# if current chunk has more than 2 sentences, pick part of it `A` (first) sentence
|
279 |
+
if len(current_chunk) >= 2:
|
280 |
+
a_end = random.randint(1, len(current_chunk) - 1)
|
281 |
+
# token a
|
282 |
+
tokens_a = []
|
283 |
+
for j in range(a_end):
|
284 |
+
tokens_a.extend(current_chunk[j])
|
285 |
+
|
286 |
+
# token b
|
287 |
+
tokens_b = []
|
288 |
+
for j in range(a_end, len(current_chunk)):
|
289 |
+
tokens_b.extend(current_chunk[j])
|
290 |
+
|
291 |
+
if len(tokens_a) == 0 or len(tokens_b) == 0:
|
292 |
+
continue
|
293 |
+
|
294 |
+
# switch tokens_a and tokens_b randomly
|
295 |
+
if random.random() < 0.5:
|
296 |
+
is_next = False
|
297 |
+
tokens_a, tokens_b = tokens_b, tokens_a
|
298 |
+
else:
|
299 |
+
is_next = True
|
300 |
+
|
301 |
+
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
|
302 |
+
"""Truncates a pair of sequences to a maximum sequence length."""
|
303 |
+
while True:
|
304 |
+
total_length = len(tokens_a) + len(tokens_b)
|
305 |
+
if total_length <= max_num_tokens:
|
306 |
+
break
|
307 |
+
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
|
308 |
+
if not (len(trunc_tokens) >= 1):
|
309 |
+
raise ValueError("Sequence length to be truncated must be no less than one")
|
310 |
+
# We want to sometimes truncate from the front and sometimes from the
|
311 |
+
# back to add more randomness and avoid biases.
|
312 |
+
if random.random() < 0.5:
|
313 |
+
del trunc_tokens[0]
|
314 |
+
else:
|
315 |
+
trunc_tokens.pop()
|
316 |
+
|
317 |
+
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
|
318 |
+
if not (len(tokens_a) >= 1):
|
319 |
+
raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
|
320 |
+
if not (len(tokens_b) >= 1):
|
321 |
+
raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
|
322 |
+
|
323 |
+
# add special tokens
|
324 |
+
input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
|
325 |
+
# add token type ids, 0 for sentence a, 1 for sentence b
|
326 |
+
token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
|
327 |
+
|
328 |
+
example = {
|
329 |
+
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
330 |
+
"token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
|
331 |
+
"sentence_order_label": torch.tensor(0 if is_next else 1, dtype=torch.long),
|
332 |
+
}
|
333 |
+
examples.append(example)
|
334 |
+
current_chunk = [] # clear current chunk
|
335 |
+
current_length = 0 # reset current text length
|
336 |
+
i += 1 # go to next line
|
337 |
+
return examples
|
338 |
+
|
339 |
+
def __len__(self):
|
340 |
+
return len(self.examples)
|
341 |
+
|
342 |
+
def __getitem__(self, i) -> Dict[str, torch.tensor]:
|
343 |
+
return self.examples[i]
|
344 |
+
|
345 |
+
|
346 |
+
class TextDatasetForNextSentencePrediction(Dataset):
|
347 |
+
"""
|
348 |
+
This will be superseded by a framework-agnostic approach soon.
|
349 |
+
"""
|
350 |
+
|
351 |
+
def __init__(
|
352 |
+
self,
|
353 |
+
tokenizer: PreTrainedTokenizer,
|
354 |
+
file_path: str,
|
355 |
+
block_size: int,
|
356 |
+
overwrite_cache=False,
|
357 |
+
short_seq_probability=0.1,
|
358 |
+
nsp_probability=0.5,
|
359 |
+
):
|
360 |
+
warnings.warn(
|
361 |
+
DEPRECATION_WARNING.format(
|
362 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py"
|
363 |
+
),
|
364 |
+
FutureWarning,
|
365 |
+
)
|
366 |
+
if not os.path.isfile(file_path):
|
367 |
+
raise ValueError(f"Input file path {file_path} not found")
|
368 |
+
|
369 |
+
self.short_seq_probability = short_seq_probability
|
370 |
+
self.nsp_probability = nsp_probability
|
371 |
+
|
372 |
+
directory, filename = os.path.split(file_path)
|
373 |
+
cached_features_file = os.path.join(
|
374 |
+
directory,
|
375 |
+
f"cached_nsp_{tokenizer.__class__.__name__}_{block_size}_{filename}",
|
376 |
+
)
|
377 |
+
|
378 |
+
self.tokenizer = tokenizer
|
379 |
+
|
380 |
+
# Make sure only the first process in distributed training processes the dataset,
|
381 |
+
# and the others will use the cache.
|
382 |
+
lock_path = cached_features_file + ".lock"
|
383 |
+
|
384 |
+
# Input file format:
|
385 |
+
# (1) One sentence per line. These should ideally be actual sentences, not
|
386 |
+
# entire paragraphs or arbitrary spans of text. (Because we use the
|
387 |
+
# sentence boundaries for the "next sentence prediction" task).
|
388 |
+
# (2) Blank lines between documents. Document boundaries are needed so
|
389 |
+
# that the "next sentence prediction" task doesn't span between documents.
|
390 |
+
#
|
391 |
+
# Example:
|
392 |
+
# I am very happy.
|
393 |
+
# Here is the second sentence.
|
394 |
+
#
|
395 |
+
# A new document.
|
396 |
+
|
397 |
+
with FileLock(lock_path):
|
398 |
+
if os.path.exists(cached_features_file) and not overwrite_cache:
|
399 |
+
start = time.time()
|
400 |
+
with open(cached_features_file, "rb") as handle:
|
401 |
+
self.examples = pickle.load(handle)
|
402 |
+
logger.info(
|
403 |
+
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
404 |
+
)
|
405 |
+
else:
|
406 |
+
logger.info(f"Creating features from dataset file at {directory}")
|
407 |
+
|
408 |
+
self.documents = [[]]
|
409 |
+
with open(file_path, encoding="utf-8") as f:
|
410 |
+
while True:
|
411 |
+
line = f.readline()
|
412 |
+
if not line:
|
413 |
+
break
|
414 |
+
line = line.strip()
|
415 |
+
|
416 |
+
# Empty lines are used as document delimiters
|
417 |
+
if not line and len(self.documents[-1]) != 0:
|
418 |
+
self.documents.append([])
|
419 |
+
tokens = tokenizer.tokenize(line)
|
420 |
+
tokens = tokenizer.convert_tokens_to_ids(tokens)
|
421 |
+
if tokens:
|
422 |
+
self.documents[-1].append(tokens)
|
423 |
+
|
424 |
+
logger.info(f"Creating examples from {len(self.documents)} documents.")
|
425 |
+
self.examples = []
|
426 |
+
for doc_index, document in enumerate(self.documents):
|
427 |
+
self.create_examples_from_document(document, doc_index, block_size)
|
428 |
+
|
429 |
+
start = time.time()
|
430 |
+
with open(cached_features_file, "wb") as handle:
|
431 |
+
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
432 |
+
logger.info(
|
433 |
+
f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
|
434 |
+
)
|
435 |
+
|
436 |
+
def create_examples_from_document(self, document: List[List[int]], doc_index: int, block_size: int):
|
437 |
+
"""Creates examples for a single document."""
|
438 |
+
|
439 |
+
max_num_tokens = block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
|
440 |
+
|
441 |
+
# We *usually* want to fill up the entire sequence since we are padding
|
442 |
+
# to `block_size` anyways, so short sequences are generally wasted
|
443 |
+
# computation. However, we *sometimes*
|
444 |
+
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
|
445 |
+
# sequences to minimize the mismatch between pretraining and fine-tuning.
|
446 |
+
# The `target_seq_length` is just a rough target however, whereas
|
447 |
+
# `block_size` is a hard limit.
|
448 |
+
target_seq_length = max_num_tokens
|
449 |
+
if random.random() < self.short_seq_probability:
|
450 |
+
target_seq_length = random.randint(2, max_num_tokens)
|
451 |
+
|
452 |
+
current_chunk = [] # a buffer stored current working segments
|
453 |
+
current_length = 0
|
454 |
+
i = 0
|
455 |
+
|
456 |
+
while i < len(document):
|
457 |
+
segment = document[i]
|
458 |
+
current_chunk.append(segment)
|
459 |
+
current_length += len(segment)
|
460 |
+
if i == len(document) - 1 or current_length >= target_seq_length:
|
461 |
+
if current_chunk:
|
462 |
+
# `a_end` is how many segments from `current_chunk` go into the `A`
|
463 |
+
# (first) sentence.
|
464 |
+
a_end = 1
|
465 |
+
if len(current_chunk) >= 2:
|
466 |
+
a_end = random.randint(1, len(current_chunk) - 1)
|
467 |
+
|
468 |
+
tokens_a = []
|
469 |
+
for j in range(a_end):
|
470 |
+
tokens_a.extend(current_chunk[j])
|
471 |
+
|
472 |
+
tokens_b = []
|
473 |
+
|
474 |
+
if len(current_chunk) == 1 or random.random() < self.nsp_probability:
|
475 |
+
is_random_next = True
|
476 |
+
target_b_length = target_seq_length - len(tokens_a)
|
477 |
+
|
478 |
+
# This should rarely go for more than one iteration for large
|
479 |
+
# corpora. However, just to be careful, we try to make sure that
|
480 |
+
# the random document is not the same as the document
|
481 |
+
# we're processing.
|
482 |
+
for _ in range(10):
|
483 |
+
random_document_index = random.randint(0, len(self.documents) - 1)
|
484 |
+
if random_document_index != doc_index:
|
485 |
+
break
|
486 |
+
|
487 |
+
random_document = self.documents[random_document_index]
|
488 |
+
random_start = random.randint(0, len(random_document) - 1)
|
489 |
+
for j in range(random_start, len(random_document)):
|
490 |
+
tokens_b.extend(random_document[j])
|
491 |
+
if len(tokens_b) >= target_b_length:
|
492 |
+
break
|
493 |
+
# We didn't actually use these segments so we "put them back" so
|
494 |
+
# they don't go to waste.
|
495 |
+
num_unused_segments = len(current_chunk) - a_end
|
496 |
+
i -= num_unused_segments
|
497 |
+
# Actual next
|
498 |
+
else:
|
499 |
+
is_random_next = False
|
500 |
+
for j in range(a_end, len(current_chunk)):
|
501 |
+
tokens_b.extend(current_chunk[j])
|
502 |
+
|
503 |
+
if not (len(tokens_a) >= 1):
|
504 |
+
raise ValueError(f"Length of sequence a is {len(tokens_a)} which must be no less than 1")
|
505 |
+
if not (len(tokens_b) >= 1):
|
506 |
+
raise ValueError(f"Length of sequence b is {len(tokens_b)} which must be no less than 1")
|
507 |
+
|
508 |
+
# add special tokens
|
509 |
+
input_ids = self.tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
|
510 |
+
# add token type ids, 0 for sentence a, 1 for sentence b
|
511 |
+
token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
|
512 |
+
|
513 |
+
example = {
|
514 |
+
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
515 |
+
"token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
|
516 |
+
"next_sentence_label": torch.tensor(1 if is_random_next else 0, dtype=torch.long),
|
517 |
+
}
|
518 |
+
|
519 |
+
self.examples.append(example)
|
520 |
+
|
521 |
+
current_chunk = []
|
522 |
+
current_length = 0
|
523 |
+
|
524 |
+
i += 1
|
525 |
+
|
526 |
+
def __len__(self):
|
527 |
+
return len(self.examples)
|
528 |
+
|
529 |
+
def __getitem__(self, i):
|
530 |
+
return self.examples[i]
|
transformers_4_35_0/data/datasets/squad.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import time
|
17 |
+
from dataclasses import dataclass, field
|
18 |
+
from enum import Enum
|
19 |
+
from typing import Dict, List, Optional, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from filelock import FileLock
|
23 |
+
from torch.utils.data import Dataset
|
24 |
+
|
25 |
+
from ...models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING
|
26 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
27 |
+
from ...utils import logging
|
28 |
+
from ..processors.squad import SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__)
|
32 |
+
|
33 |
+
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
|
34 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
35 |
+
|
36 |
+
|
37 |
+
@dataclass
|
38 |
+
class SquadDataTrainingArguments:
|
39 |
+
"""
|
40 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
41 |
+
"""
|
42 |
+
|
43 |
+
model_type: str = field(
|
44 |
+
default=None, metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_TYPES)}
|
45 |
+
)
|
46 |
+
data_dir: str = field(
|
47 |
+
default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
|
48 |
+
)
|
49 |
+
max_seq_length: int = field(
|
50 |
+
default=128,
|
51 |
+
metadata={
|
52 |
+
"help": (
|
53 |
+
"The maximum total input sequence length after tokenization. Sequences longer "
|
54 |
+
"than this will be truncated, sequences shorter will be padded."
|
55 |
+
)
|
56 |
+
},
|
57 |
+
)
|
58 |
+
doc_stride: int = field(
|
59 |
+
default=128,
|
60 |
+
metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."},
|
61 |
+
)
|
62 |
+
max_query_length: int = field(
|
63 |
+
default=64,
|
64 |
+
metadata={
|
65 |
+
"help": (
|
66 |
+
"The maximum number of tokens for the question. Questions longer than this will "
|
67 |
+
"be truncated to this length."
|
68 |
+
)
|
69 |
+
},
|
70 |
+
)
|
71 |
+
max_answer_length: int = field(
|
72 |
+
default=30,
|
73 |
+
metadata={
|
74 |
+
"help": (
|
75 |
+
"The maximum length of an answer that can be generated. This is needed because the start "
|
76 |
+
"and end predictions are not conditioned on one another."
|
77 |
+
)
|
78 |
+
},
|
79 |
+
)
|
80 |
+
overwrite_cache: bool = field(
|
81 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
82 |
+
)
|
83 |
+
version_2_with_negative: bool = field(
|
84 |
+
default=False, metadata={"help": "If true, the SQuAD examples contain some that do not have an answer."}
|
85 |
+
)
|
86 |
+
null_score_diff_threshold: float = field(
|
87 |
+
default=0.0, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
|
88 |
+
)
|
89 |
+
n_best_size: int = field(
|
90 |
+
default=20, metadata={"help": "If null_score - best_non_null is greater than the threshold predict null."}
|
91 |
+
)
|
92 |
+
lang_id: int = field(
|
93 |
+
default=0,
|
94 |
+
metadata={
|
95 |
+
"help": (
|
96 |
+
"language id of input for language-specific xlm models (see"
|
97 |
+
" tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
|
98 |
+
)
|
99 |
+
},
|
100 |
+
)
|
101 |
+
threads: int = field(default=1, metadata={"help": "multiple threads for converting example to features"})
|
102 |
+
|
103 |
+
|
104 |
+
class Split(Enum):
|
105 |
+
train = "train"
|
106 |
+
dev = "dev"
|
107 |
+
|
108 |
+
|
109 |
+
class SquadDataset(Dataset):
|
110 |
+
"""
|
111 |
+
This will be superseded by a framework-agnostic approach soon.
|
112 |
+
"""
|
113 |
+
|
114 |
+
args: SquadDataTrainingArguments
|
115 |
+
features: List[SquadFeatures]
|
116 |
+
mode: Split
|
117 |
+
is_language_sensitive: bool
|
118 |
+
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
args: SquadDataTrainingArguments,
|
122 |
+
tokenizer: PreTrainedTokenizer,
|
123 |
+
limit_length: Optional[int] = None,
|
124 |
+
mode: Union[str, Split] = Split.train,
|
125 |
+
is_language_sensitive: Optional[bool] = False,
|
126 |
+
cache_dir: Optional[str] = None,
|
127 |
+
dataset_format: Optional[str] = "pt",
|
128 |
+
):
|
129 |
+
self.args = args
|
130 |
+
self.is_language_sensitive = is_language_sensitive
|
131 |
+
self.processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
|
132 |
+
if isinstance(mode, str):
|
133 |
+
try:
|
134 |
+
mode = Split[mode]
|
135 |
+
except KeyError:
|
136 |
+
raise KeyError("mode is not a valid split name")
|
137 |
+
self.mode = mode
|
138 |
+
# Load data features from cache or dataset file
|
139 |
+
version_tag = "v2" if args.version_2_with_negative else "v1"
|
140 |
+
cached_features_file = os.path.join(
|
141 |
+
cache_dir if cache_dir is not None else args.data_dir,
|
142 |
+
f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{version_tag}",
|
143 |
+
)
|
144 |
+
|
145 |
+
# Make sure only the first process in distributed training processes the dataset,
|
146 |
+
# and the others will use the cache.
|
147 |
+
lock_path = cached_features_file + ".lock"
|
148 |
+
with FileLock(lock_path):
|
149 |
+
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
150 |
+
start = time.time()
|
151 |
+
self.old_features = torch.load(cached_features_file)
|
152 |
+
|
153 |
+
# Legacy cache files have only features, while new cache files
|
154 |
+
# will have dataset and examples also.
|
155 |
+
self.features = self.old_features["features"]
|
156 |
+
self.dataset = self.old_features.get("dataset", None)
|
157 |
+
self.examples = self.old_features.get("examples", None)
|
158 |
+
logger.info(
|
159 |
+
f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
|
160 |
+
)
|
161 |
+
|
162 |
+
if self.dataset is None or self.examples is None:
|
163 |
+
logger.warning(
|
164 |
+
f"Deleting cached file {cached_features_file} will allow dataset and examples to be cached in"
|
165 |
+
" future run"
|
166 |
+
)
|
167 |
+
else:
|
168 |
+
if mode == Split.dev:
|
169 |
+
self.examples = self.processor.get_dev_examples(args.data_dir)
|
170 |
+
else:
|
171 |
+
self.examples = self.processor.get_train_examples(args.data_dir)
|
172 |
+
|
173 |
+
self.features, self.dataset = squad_convert_examples_to_features(
|
174 |
+
examples=self.examples,
|
175 |
+
tokenizer=tokenizer,
|
176 |
+
max_seq_length=args.max_seq_length,
|
177 |
+
doc_stride=args.doc_stride,
|
178 |
+
max_query_length=args.max_query_length,
|
179 |
+
is_training=mode == Split.train,
|
180 |
+
threads=args.threads,
|
181 |
+
return_dataset=dataset_format,
|
182 |
+
)
|
183 |
+
|
184 |
+
start = time.time()
|
185 |
+
torch.save(
|
186 |
+
{"features": self.features, "dataset": self.dataset, "examples": self.examples},
|
187 |
+
cached_features_file,
|
188 |
+
)
|
189 |
+
# ^ This seems to take a lot of time so I want to investigate why and how we can improve.
|
190 |
+
logger.info(
|
191 |
+
f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
|
192 |
+
)
|
193 |
+
|
194 |
+
def __len__(self):
|
195 |
+
return len(self.features)
|
196 |
+
|
197 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
198 |
+
# Convert to Tensors and build dataset
|
199 |
+
feature = self.features[i]
|
200 |
+
|
201 |
+
input_ids = torch.tensor(feature.input_ids, dtype=torch.long)
|
202 |
+
attention_mask = torch.tensor(feature.attention_mask, dtype=torch.long)
|
203 |
+
token_type_ids = torch.tensor(feature.token_type_ids, dtype=torch.long)
|
204 |
+
cls_index = torch.tensor(feature.cls_index, dtype=torch.long)
|
205 |
+
p_mask = torch.tensor(feature.p_mask, dtype=torch.float)
|
206 |
+
is_impossible = torch.tensor(feature.is_impossible, dtype=torch.float)
|
207 |
+
|
208 |
+
inputs = {
|
209 |
+
"input_ids": input_ids,
|
210 |
+
"attention_mask": attention_mask,
|
211 |
+
"token_type_ids": token_type_ids,
|
212 |
+
}
|
213 |
+
|
214 |
+
if self.args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
|
215 |
+
del inputs["token_type_ids"]
|
216 |
+
|
217 |
+
if self.args.model_type in ["xlnet", "xlm"]:
|
218 |
+
inputs.update({"cls_index": cls_index, "p_mask": p_mask})
|
219 |
+
if self.args.version_2_with_negative:
|
220 |
+
inputs.update({"is_impossible": is_impossible})
|
221 |
+
if self.is_language_sensitive:
|
222 |
+
inputs.update({"langs": (torch.ones(input_ids.shape, dtype=torch.int64) * self.args.lang_id)})
|
223 |
+
|
224 |
+
if self.mode == Split.train:
|
225 |
+
start_positions = torch.tensor(feature.start_position, dtype=torch.long)
|
226 |
+
end_positions = torch.tensor(feature.end_position, dtype=torch.long)
|
227 |
+
inputs.update({"start_positions": start_positions, "end_positions": end_positions})
|
228 |
+
|
229 |
+
return inputs
|
transformers_4_35_0/data/metrics/__init__.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
2 |
+
# you may not use this file except in compliance with the License.
|
3 |
+
# You may obtain a copy of the License at
|
4 |
+
#
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
#
|
7 |
+
# Unless required by applicable law or agreed to in writing, software
|
8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
10 |
+
# See the License for the specific language governing permissions and
|
11 |
+
# limitations under the License.
|
12 |
+
|
13 |
+
import warnings
|
14 |
+
|
15 |
+
from ...utils import is_sklearn_available, requires_backends
|
16 |
+
|
17 |
+
|
18 |
+
if is_sklearn_available():
|
19 |
+
from scipy.stats import pearsonr, spearmanr
|
20 |
+
from sklearn.metrics import f1_score, matthews_corrcoef
|
21 |
+
|
22 |
+
|
23 |
+
DEPRECATION_WARNING = (
|
24 |
+
"This metric will be removed from the library soon, metrics should be handled with the 🤗 Evaluate "
|
25 |
+
"library. You can have a look at this example script for pointers: "
|
26 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def simple_accuracy(preds, labels):
|
31 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
32 |
+
requires_backends(simple_accuracy, "sklearn")
|
33 |
+
return (preds == labels).mean()
|
34 |
+
|
35 |
+
|
36 |
+
def acc_and_f1(preds, labels):
|
37 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
38 |
+
requires_backends(acc_and_f1, "sklearn")
|
39 |
+
acc = simple_accuracy(preds, labels)
|
40 |
+
f1 = f1_score(y_true=labels, y_pred=preds)
|
41 |
+
return {
|
42 |
+
"acc": acc,
|
43 |
+
"f1": f1,
|
44 |
+
"acc_and_f1": (acc + f1) / 2,
|
45 |
+
}
|
46 |
+
|
47 |
+
|
48 |
+
def pearson_and_spearman(preds, labels):
|
49 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
50 |
+
requires_backends(pearson_and_spearman, "sklearn")
|
51 |
+
pearson_corr = pearsonr(preds, labels)[0]
|
52 |
+
spearman_corr = spearmanr(preds, labels)[0]
|
53 |
+
return {
|
54 |
+
"pearson": pearson_corr,
|
55 |
+
"spearmanr": spearman_corr,
|
56 |
+
"corr": (pearson_corr + spearman_corr) / 2,
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
def glue_compute_metrics(task_name, preds, labels):
|
61 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
62 |
+
requires_backends(glue_compute_metrics, "sklearn")
|
63 |
+
assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}"
|
64 |
+
if task_name == "cola":
|
65 |
+
return {"mcc": matthews_corrcoef(labels, preds)}
|
66 |
+
elif task_name == "sst-2":
|
67 |
+
return {"acc": simple_accuracy(preds, labels)}
|
68 |
+
elif task_name == "mrpc":
|
69 |
+
return acc_and_f1(preds, labels)
|
70 |
+
elif task_name == "sts-b":
|
71 |
+
return pearson_and_spearman(preds, labels)
|
72 |
+
elif task_name == "qqp":
|
73 |
+
return acc_and_f1(preds, labels)
|
74 |
+
elif task_name == "mnli":
|
75 |
+
return {"mnli/acc": simple_accuracy(preds, labels)}
|
76 |
+
elif task_name == "mnli-mm":
|
77 |
+
return {"mnli-mm/acc": simple_accuracy(preds, labels)}
|
78 |
+
elif task_name == "qnli":
|
79 |
+
return {"acc": simple_accuracy(preds, labels)}
|
80 |
+
elif task_name == "rte":
|
81 |
+
return {"acc": simple_accuracy(preds, labels)}
|
82 |
+
elif task_name == "wnli":
|
83 |
+
return {"acc": simple_accuracy(preds, labels)}
|
84 |
+
elif task_name == "hans":
|
85 |
+
return {"acc": simple_accuracy(preds, labels)}
|
86 |
+
else:
|
87 |
+
raise KeyError(task_name)
|
88 |
+
|
89 |
+
|
90 |
+
def xnli_compute_metrics(task_name, preds, labels):
|
91 |
+
warnings.warn(DEPRECATION_WARNING, FutureWarning)
|
92 |
+
requires_backends(xnli_compute_metrics, "sklearn")
|
93 |
+
if len(preds) != len(labels):
|
94 |
+
raise ValueError(f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}")
|
95 |
+
if task_name == "xnli":
|
96 |
+
return {"acc": simple_accuracy(preds, labels)}
|
97 |
+
else:
|
98 |
+
raise KeyError(task_name)
|
transformers_4_35_0/data/metrics/squad_metrics.py
ADDED
@@ -0,0 +1,780 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was modified by XLNet authors to
|
16 |
+
update `find_best_threshold` scripts for SQuAD V2.0
|
17 |
+
|
18 |
+
In addition to basic functionality, we also compute additional statistics and plot precision-recall curves if an
|
19 |
+
additional na_prob.json file is provided. This file is expected to map question ID's to the model's predicted
|
20 |
+
probability that a question is unanswerable.
|
21 |
+
"""
|
22 |
+
|
23 |
+
|
24 |
+
import collections
|
25 |
+
import json
|
26 |
+
import math
|
27 |
+
import re
|
28 |
+
import string
|
29 |
+
|
30 |
+
from ...models.bert import BasicTokenizer
|
31 |
+
from ...utils import logging
|
32 |
+
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__)
|
35 |
+
|
36 |
+
|
37 |
+
def normalize_answer(s):
|
38 |
+
"""Lower text and remove punctuation, articles and extra whitespace."""
|
39 |
+
|
40 |
+
def remove_articles(text):
|
41 |
+
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
|
42 |
+
return re.sub(regex, " ", text)
|
43 |
+
|
44 |
+
def white_space_fix(text):
|
45 |
+
return " ".join(text.split())
|
46 |
+
|
47 |
+
def remove_punc(text):
|
48 |
+
exclude = set(string.punctuation)
|
49 |
+
return "".join(ch for ch in text if ch not in exclude)
|
50 |
+
|
51 |
+
def lower(text):
|
52 |
+
return text.lower()
|
53 |
+
|
54 |
+
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
55 |
+
|
56 |
+
|
57 |
+
def get_tokens(s):
|
58 |
+
if not s:
|
59 |
+
return []
|
60 |
+
return normalize_answer(s).split()
|
61 |
+
|
62 |
+
|
63 |
+
def compute_exact(a_gold, a_pred):
|
64 |
+
return int(normalize_answer(a_gold) == normalize_answer(a_pred))
|
65 |
+
|
66 |
+
|
67 |
+
def compute_f1(a_gold, a_pred):
|
68 |
+
gold_toks = get_tokens(a_gold)
|
69 |
+
pred_toks = get_tokens(a_pred)
|
70 |
+
common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
|
71 |
+
num_same = sum(common.values())
|
72 |
+
if len(gold_toks) == 0 or len(pred_toks) == 0:
|
73 |
+
# If either is no-answer, then F1 is 1 if they agree, 0 otherwise
|
74 |
+
return int(gold_toks == pred_toks)
|
75 |
+
if num_same == 0:
|
76 |
+
return 0
|
77 |
+
precision = 1.0 * num_same / len(pred_toks)
|
78 |
+
recall = 1.0 * num_same / len(gold_toks)
|
79 |
+
f1 = (2 * precision * recall) / (precision + recall)
|
80 |
+
return f1
|
81 |
+
|
82 |
+
|
83 |
+
def get_raw_scores(examples, preds):
|
84 |
+
"""
|
85 |
+
Computes the exact and f1 scores from the examples and the model predictions
|
86 |
+
"""
|
87 |
+
exact_scores = {}
|
88 |
+
f1_scores = {}
|
89 |
+
|
90 |
+
for example in examples:
|
91 |
+
qas_id = example.qas_id
|
92 |
+
gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])]
|
93 |
+
|
94 |
+
if not gold_answers:
|
95 |
+
# For unanswerable questions, only correct answer is empty string
|
96 |
+
gold_answers = [""]
|
97 |
+
|
98 |
+
if qas_id not in preds:
|
99 |
+
print(f"Missing prediction for {qas_id}")
|
100 |
+
continue
|
101 |
+
|
102 |
+
prediction = preds[qas_id]
|
103 |
+
exact_scores[qas_id] = max(compute_exact(a, prediction) for a in gold_answers)
|
104 |
+
f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers)
|
105 |
+
|
106 |
+
return exact_scores, f1_scores
|
107 |
+
|
108 |
+
|
109 |
+
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
|
110 |
+
new_scores = {}
|
111 |
+
for qid, s in scores.items():
|
112 |
+
pred_na = na_probs[qid] > na_prob_thresh
|
113 |
+
if pred_na:
|
114 |
+
new_scores[qid] = float(not qid_to_has_ans[qid])
|
115 |
+
else:
|
116 |
+
new_scores[qid] = s
|
117 |
+
return new_scores
|
118 |
+
|
119 |
+
|
120 |
+
def make_eval_dict(exact_scores, f1_scores, qid_list=None):
|
121 |
+
if not qid_list:
|
122 |
+
total = len(exact_scores)
|
123 |
+
return collections.OrderedDict(
|
124 |
+
[
|
125 |
+
("exact", 100.0 * sum(exact_scores.values()) / total),
|
126 |
+
("f1", 100.0 * sum(f1_scores.values()) / total),
|
127 |
+
("total", total),
|
128 |
+
]
|
129 |
+
)
|
130 |
+
else:
|
131 |
+
total = len(qid_list)
|
132 |
+
return collections.OrderedDict(
|
133 |
+
[
|
134 |
+
("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total),
|
135 |
+
("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total),
|
136 |
+
("total", total),
|
137 |
+
]
|
138 |
+
)
|
139 |
+
|
140 |
+
|
141 |
+
def merge_eval(main_eval, new_eval, prefix):
|
142 |
+
for k in new_eval:
|
143 |
+
main_eval[f"{prefix}_{k}"] = new_eval[k]
|
144 |
+
|
145 |
+
|
146 |
+
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
|
147 |
+
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
148 |
+
cur_score = num_no_ans
|
149 |
+
best_score = cur_score
|
150 |
+
best_thresh = 0.0
|
151 |
+
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
152 |
+
for i, qid in enumerate(qid_list):
|
153 |
+
if qid not in scores:
|
154 |
+
continue
|
155 |
+
if qid_to_has_ans[qid]:
|
156 |
+
diff = scores[qid]
|
157 |
+
else:
|
158 |
+
if preds[qid]:
|
159 |
+
diff = -1
|
160 |
+
else:
|
161 |
+
diff = 0
|
162 |
+
cur_score += diff
|
163 |
+
if cur_score > best_score:
|
164 |
+
best_score = cur_score
|
165 |
+
best_thresh = na_probs[qid]
|
166 |
+
|
167 |
+
has_ans_score, has_ans_cnt = 0, 0
|
168 |
+
for qid in qid_list:
|
169 |
+
if not qid_to_has_ans[qid]:
|
170 |
+
continue
|
171 |
+
has_ans_cnt += 1
|
172 |
+
|
173 |
+
if qid not in scores:
|
174 |
+
continue
|
175 |
+
has_ans_score += scores[qid]
|
176 |
+
|
177 |
+
return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
|
178 |
+
|
179 |
+
|
180 |
+
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
181 |
+
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
|
182 |
+
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
|
183 |
+
main_eval["best_exact"] = best_exact
|
184 |
+
main_eval["best_exact_thresh"] = exact_thresh
|
185 |
+
main_eval["best_f1"] = best_f1
|
186 |
+
main_eval["best_f1_thresh"] = f1_thresh
|
187 |
+
main_eval["has_ans_exact"] = has_ans_exact
|
188 |
+
main_eval["has_ans_f1"] = has_ans_f1
|
189 |
+
|
190 |
+
|
191 |
+
def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
192 |
+
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
193 |
+
cur_score = num_no_ans
|
194 |
+
best_score = cur_score
|
195 |
+
best_thresh = 0.0
|
196 |
+
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
197 |
+
for _, qid in enumerate(qid_list):
|
198 |
+
if qid not in scores:
|
199 |
+
continue
|
200 |
+
if qid_to_has_ans[qid]:
|
201 |
+
diff = scores[qid]
|
202 |
+
else:
|
203 |
+
if preds[qid]:
|
204 |
+
diff = -1
|
205 |
+
else:
|
206 |
+
diff = 0
|
207 |
+
cur_score += diff
|
208 |
+
if cur_score > best_score:
|
209 |
+
best_score = cur_score
|
210 |
+
best_thresh = na_probs[qid]
|
211 |
+
return 100.0 * best_score / len(scores), best_thresh
|
212 |
+
|
213 |
+
|
214 |
+
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
215 |
+
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
|
216 |
+
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
|
217 |
+
|
218 |
+
main_eval["best_exact"] = best_exact
|
219 |
+
main_eval["best_exact_thresh"] = exact_thresh
|
220 |
+
main_eval["best_f1"] = best_f1
|
221 |
+
main_eval["best_f1_thresh"] = f1_thresh
|
222 |
+
|
223 |
+
|
224 |
+
def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0):
|
225 |
+
qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}
|
226 |
+
has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer]
|
227 |
+
no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer]
|
228 |
+
|
229 |
+
if no_answer_probs is None:
|
230 |
+
no_answer_probs = {k: 0.0 for k in preds}
|
231 |
+
|
232 |
+
exact, f1 = get_raw_scores(examples, preds)
|
233 |
+
|
234 |
+
exact_threshold = apply_no_ans_threshold(
|
235 |
+
exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold
|
236 |
+
)
|
237 |
+
f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold)
|
238 |
+
|
239 |
+
evaluation = make_eval_dict(exact_threshold, f1_threshold)
|
240 |
+
|
241 |
+
if has_answer_qids:
|
242 |
+
has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids)
|
243 |
+
merge_eval(evaluation, has_ans_eval, "HasAns")
|
244 |
+
|
245 |
+
if no_answer_qids:
|
246 |
+
no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids)
|
247 |
+
merge_eval(evaluation, no_ans_eval, "NoAns")
|
248 |
+
|
249 |
+
if no_answer_probs:
|
250 |
+
find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer)
|
251 |
+
|
252 |
+
return evaluation
|
253 |
+
|
254 |
+
|
255 |
+
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
256 |
+
"""Project the tokenized prediction back to the original text."""
|
257 |
+
|
258 |
+
# When we created the data, we kept track of the alignment between original
|
259 |
+
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
|
260 |
+
# now `orig_text` contains the span of our original text corresponding to the
|
261 |
+
# span that we predicted.
|
262 |
+
#
|
263 |
+
# However, `orig_text` may contain extra characters that we don't want in
|
264 |
+
# our prediction.
|
265 |
+
#
|
266 |
+
# For example, let's say:
|
267 |
+
# pred_text = steve smith
|
268 |
+
# orig_text = Steve Smith's
|
269 |
+
#
|
270 |
+
# We don't want to return `orig_text` because it contains the extra "'s".
|
271 |
+
#
|
272 |
+
# We don't want to return `pred_text` because it's already been normalized
|
273 |
+
# (the SQuAD eval script also does punctuation stripping/lower casing but
|
274 |
+
# our tokenizer does additional normalization like stripping accent
|
275 |
+
# characters).
|
276 |
+
#
|
277 |
+
# What we really want to return is "Steve Smith".
|
278 |
+
#
|
279 |
+
# Therefore, we have to apply a semi-complicated alignment heuristic between
|
280 |
+
# `pred_text` and `orig_text` to get a character-to-character alignment. This
|
281 |
+
# can fail in certain cases in which case we just return `orig_text`.
|
282 |
+
|
283 |
+
def _strip_spaces(text):
|
284 |
+
ns_chars = []
|
285 |
+
ns_to_s_map = collections.OrderedDict()
|
286 |
+
for i, c in enumerate(text):
|
287 |
+
if c == " ":
|
288 |
+
continue
|
289 |
+
ns_to_s_map[len(ns_chars)] = i
|
290 |
+
ns_chars.append(c)
|
291 |
+
ns_text = "".join(ns_chars)
|
292 |
+
return (ns_text, ns_to_s_map)
|
293 |
+
|
294 |
+
# We first tokenize `orig_text`, strip whitespace from the result
|
295 |
+
# and `pred_text`, and check if they are the same length. If they are
|
296 |
+
# NOT the same length, the heuristic has failed. If they are the same
|
297 |
+
# length, we assume the characters are one-to-one aligned.
|
298 |
+
tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
|
299 |
+
|
300 |
+
tok_text = " ".join(tokenizer.tokenize(orig_text))
|
301 |
+
|
302 |
+
start_position = tok_text.find(pred_text)
|
303 |
+
if start_position == -1:
|
304 |
+
if verbose_logging:
|
305 |
+
logger.info(f"Unable to find text: '{pred_text}' in '{orig_text}'")
|
306 |
+
return orig_text
|
307 |
+
end_position = start_position + len(pred_text) - 1
|
308 |
+
|
309 |
+
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
|
310 |
+
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
|
311 |
+
|
312 |
+
if len(orig_ns_text) != len(tok_ns_text):
|
313 |
+
if verbose_logging:
|
314 |
+
logger.info(f"Length not equal after stripping spaces: '{orig_ns_text}' vs '{tok_ns_text}'")
|
315 |
+
return orig_text
|
316 |
+
|
317 |
+
# We then project the characters in `pred_text` back to `orig_text` using
|
318 |
+
# the character-to-character alignment.
|
319 |
+
tok_s_to_ns_map = {}
|
320 |
+
for i, tok_index in tok_ns_to_s_map.items():
|
321 |
+
tok_s_to_ns_map[tok_index] = i
|
322 |
+
|
323 |
+
orig_start_position = None
|
324 |
+
if start_position in tok_s_to_ns_map:
|
325 |
+
ns_start_position = tok_s_to_ns_map[start_position]
|
326 |
+
if ns_start_position in orig_ns_to_s_map:
|
327 |
+
orig_start_position = orig_ns_to_s_map[ns_start_position]
|
328 |
+
|
329 |
+
if orig_start_position is None:
|
330 |
+
if verbose_logging:
|
331 |
+
logger.info("Couldn't map start position")
|
332 |
+
return orig_text
|
333 |
+
|
334 |
+
orig_end_position = None
|
335 |
+
if end_position in tok_s_to_ns_map:
|
336 |
+
ns_end_position = tok_s_to_ns_map[end_position]
|
337 |
+
if ns_end_position in orig_ns_to_s_map:
|
338 |
+
orig_end_position = orig_ns_to_s_map[ns_end_position]
|
339 |
+
|
340 |
+
if orig_end_position is None:
|
341 |
+
if verbose_logging:
|
342 |
+
logger.info("Couldn't map end position")
|
343 |
+
return orig_text
|
344 |
+
|
345 |
+
output_text = orig_text[orig_start_position : (orig_end_position + 1)]
|
346 |
+
return output_text
|
347 |
+
|
348 |
+
|
349 |
+
def _get_best_indexes(logits, n_best_size):
|
350 |
+
"""Get the n-best logits from a list."""
|
351 |
+
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
|
352 |
+
|
353 |
+
best_indexes = []
|
354 |
+
for i in range(len(index_and_score)):
|
355 |
+
if i >= n_best_size:
|
356 |
+
break
|
357 |
+
best_indexes.append(index_and_score[i][0])
|
358 |
+
return best_indexes
|
359 |
+
|
360 |
+
|
361 |
+
def _compute_softmax(scores):
|
362 |
+
"""Compute softmax probability over raw logits."""
|
363 |
+
if not scores:
|
364 |
+
return []
|
365 |
+
|
366 |
+
max_score = None
|
367 |
+
for score in scores:
|
368 |
+
if max_score is None or score > max_score:
|
369 |
+
max_score = score
|
370 |
+
|
371 |
+
exp_scores = []
|
372 |
+
total_sum = 0.0
|
373 |
+
for score in scores:
|
374 |
+
x = math.exp(score - max_score)
|
375 |
+
exp_scores.append(x)
|
376 |
+
total_sum += x
|
377 |
+
|
378 |
+
probs = []
|
379 |
+
for score in exp_scores:
|
380 |
+
probs.append(score / total_sum)
|
381 |
+
return probs
|
382 |
+
|
383 |
+
|
384 |
+
def compute_predictions_logits(
|
385 |
+
all_examples,
|
386 |
+
all_features,
|
387 |
+
all_results,
|
388 |
+
n_best_size,
|
389 |
+
max_answer_length,
|
390 |
+
do_lower_case,
|
391 |
+
output_prediction_file,
|
392 |
+
output_nbest_file,
|
393 |
+
output_null_log_odds_file,
|
394 |
+
verbose_logging,
|
395 |
+
version_2_with_negative,
|
396 |
+
null_score_diff_threshold,
|
397 |
+
tokenizer,
|
398 |
+
):
|
399 |
+
"""Write final predictions to the json file and log-odds of null if needed."""
|
400 |
+
if output_prediction_file:
|
401 |
+
logger.info(f"Writing predictions to: {output_prediction_file}")
|
402 |
+
if output_nbest_file:
|
403 |
+
logger.info(f"Writing nbest to: {output_nbest_file}")
|
404 |
+
if output_null_log_odds_file and version_2_with_negative:
|
405 |
+
logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}")
|
406 |
+
|
407 |
+
example_index_to_features = collections.defaultdict(list)
|
408 |
+
for feature in all_features:
|
409 |
+
example_index_to_features[feature.example_index].append(feature)
|
410 |
+
|
411 |
+
unique_id_to_result = {}
|
412 |
+
for result in all_results:
|
413 |
+
unique_id_to_result[result.unique_id] = result
|
414 |
+
|
415 |
+
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
416 |
+
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
|
417 |
+
)
|
418 |
+
|
419 |
+
all_predictions = collections.OrderedDict()
|
420 |
+
all_nbest_json = collections.OrderedDict()
|
421 |
+
scores_diff_json = collections.OrderedDict()
|
422 |
+
|
423 |
+
for example_index, example in enumerate(all_examples):
|
424 |
+
features = example_index_to_features[example_index]
|
425 |
+
|
426 |
+
prelim_predictions = []
|
427 |
+
# keep track of the minimum score of null start+end of position 0
|
428 |
+
score_null = 1000000 # large and positive
|
429 |
+
min_null_feature_index = 0 # the paragraph slice with min null score
|
430 |
+
null_start_logit = 0 # the start logit at the slice with min null score
|
431 |
+
null_end_logit = 0 # the end logit at the slice with min null score
|
432 |
+
for feature_index, feature in enumerate(features):
|
433 |
+
result = unique_id_to_result[feature.unique_id]
|
434 |
+
start_indexes = _get_best_indexes(result.start_logits, n_best_size)
|
435 |
+
end_indexes = _get_best_indexes(result.end_logits, n_best_size)
|
436 |
+
# if we could have irrelevant answers, get the min score of irrelevant
|
437 |
+
if version_2_with_negative:
|
438 |
+
feature_null_score = result.start_logits[0] + result.end_logits[0]
|
439 |
+
if feature_null_score < score_null:
|
440 |
+
score_null = feature_null_score
|
441 |
+
min_null_feature_index = feature_index
|
442 |
+
null_start_logit = result.start_logits[0]
|
443 |
+
null_end_logit = result.end_logits[0]
|
444 |
+
for start_index in start_indexes:
|
445 |
+
for end_index in end_indexes:
|
446 |
+
# We could hypothetically create invalid predictions, e.g., predict
|
447 |
+
# that the start of the span is in the question. We throw out all
|
448 |
+
# invalid predictions.
|
449 |
+
if start_index >= len(feature.tokens):
|
450 |
+
continue
|
451 |
+
if end_index >= len(feature.tokens):
|
452 |
+
continue
|
453 |
+
if start_index not in feature.token_to_orig_map:
|
454 |
+
continue
|
455 |
+
if end_index not in feature.token_to_orig_map:
|
456 |
+
continue
|
457 |
+
if not feature.token_is_max_context.get(start_index, False):
|
458 |
+
continue
|
459 |
+
if end_index < start_index:
|
460 |
+
continue
|
461 |
+
length = end_index - start_index + 1
|
462 |
+
if length > max_answer_length:
|
463 |
+
continue
|
464 |
+
prelim_predictions.append(
|
465 |
+
_PrelimPrediction(
|
466 |
+
feature_index=feature_index,
|
467 |
+
start_index=start_index,
|
468 |
+
end_index=end_index,
|
469 |
+
start_logit=result.start_logits[start_index],
|
470 |
+
end_logit=result.end_logits[end_index],
|
471 |
+
)
|
472 |
+
)
|
473 |
+
if version_2_with_negative:
|
474 |
+
prelim_predictions.append(
|
475 |
+
_PrelimPrediction(
|
476 |
+
feature_index=min_null_feature_index,
|
477 |
+
start_index=0,
|
478 |
+
end_index=0,
|
479 |
+
start_logit=null_start_logit,
|
480 |
+
end_logit=null_end_logit,
|
481 |
+
)
|
482 |
+
)
|
483 |
+
prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
|
484 |
+
|
485 |
+
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
486 |
+
"NbestPrediction", ["text", "start_logit", "end_logit"]
|
487 |
+
)
|
488 |
+
|
489 |
+
seen_predictions = {}
|
490 |
+
nbest = []
|
491 |
+
for pred in prelim_predictions:
|
492 |
+
if len(nbest) >= n_best_size:
|
493 |
+
break
|
494 |
+
feature = features[pred.feature_index]
|
495 |
+
if pred.start_index > 0: # this is a non-null prediction
|
496 |
+
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
|
497 |
+
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
498 |
+
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
499 |
+
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
|
500 |
+
|
501 |
+
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
502 |
+
|
503 |
+
# tok_text = " ".join(tok_tokens)
|
504 |
+
#
|
505 |
+
# # De-tokenize WordPieces that have been split off.
|
506 |
+
# tok_text = tok_text.replace(" ##", "")
|
507 |
+
# tok_text = tok_text.replace("##", "")
|
508 |
+
|
509 |
+
# Clean whitespace
|
510 |
+
tok_text = tok_text.strip()
|
511 |
+
tok_text = " ".join(tok_text.split())
|
512 |
+
orig_text = " ".join(orig_tokens)
|
513 |
+
|
514 |
+
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
|
515 |
+
if final_text in seen_predictions:
|
516 |
+
continue
|
517 |
+
|
518 |
+
seen_predictions[final_text] = True
|
519 |
+
else:
|
520 |
+
final_text = ""
|
521 |
+
seen_predictions[final_text] = True
|
522 |
+
|
523 |
+
nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
|
524 |
+
# if we didn't include the empty option in the n-best, include it
|
525 |
+
if version_2_with_negative:
|
526 |
+
if "" not in seen_predictions:
|
527 |
+
nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
|
528 |
+
|
529 |
+
# In very rare edge cases we could only have single null prediction.
|
530 |
+
# So we just create a nonce prediction in this case to avoid failure.
|
531 |
+
if len(nbest) == 1:
|
532 |
+
nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
533 |
+
|
534 |
+
# In very rare edge cases we could have no valid predictions. So we
|
535 |
+
# just create a nonce prediction in this case to avoid failure.
|
536 |
+
if not nbest:
|
537 |
+
nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
|
538 |
+
|
539 |
+
if len(nbest) < 1:
|
540 |
+
raise ValueError("No valid predictions")
|
541 |
+
|
542 |
+
total_scores = []
|
543 |
+
best_non_null_entry = None
|
544 |
+
for entry in nbest:
|
545 |
+
total_scores.append(entry.start_logit + entry.end_logit)
|
546 |
+
if not best_non_null_entry:
|
547 |
+
if entry.text:
|
548 |
+
best_non_null_entry = entry
|
549 |
+
|
550 |
+
probs = _compute_softmax(total_scores)
|
551 |
+
|
552 |
+
nbest_json = []
|
553 |
+
for i, entry in enumerate(nbest):
|
554 |
+
output = collections.OrderedDict()
|
555 |
+
output["text"] = entry.text
|
556 |
+
output["probability"] = probs[i]
|
557 |
+
output["start_logit"] = entry.start_logit
|
558 |
+
output["end_logit"] = entry.end_logit
|
559 |
+
nbest_json.append(output)
|
560 |
+
|
561 |
+
if len(nbest_json) < 1:
|
562 |
+
raise ValueError("No valid predictions")
|
563 |
+
|
564 |
+
if not version_2_with_negative:
|
565 |
+
all_predictions[example.qas_id] = nbest_json[0]["text"]
|
566 |
+
else:
|
567 |
+
# predict "" iff the null score - the score of best non-null > threshold
|
568 |
+
score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
|
569 |
+
scores_diff_json[example.qas_id] = score_diff
|
570 |
+
if score_diff > null_score_diff_threshold:
|
571 |
+
all_predictions[example.qas_id] = ""
|
572 |
+
else:
|
573 |
+
all_predictions[example.qas_id] = best_non_null_entry.text
|
574 |
+
all_nbest_json[example.qas_id] = nbest_json
|
575 |
+
|
576 |
+
if output_prediction_file:
|
577 |
+
with open(output_prediction_file, "w") as writer:
|
578 |
+
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
579 |
+
|
580 |
+
if output_nbest_file:
|
581 |
+
with open(output_nbest_file, "w") as writer:
|
582 |
+
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
583 |
+
|
584 |
+
if output_null_log_odds_file and version_2_with_negative:
|
585 |
+
with open(output_null_log_odds_file, "w") as writer:
|
586 |
+
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
587 |
+
|
588 |
+
return all_predictions
|
589 |
+
|
590 |
+
|
591 |
+
def compute_predictions_log_probs(
|
592 |
+
all_examples,
|
593 |
+
all_features,
|
594 |
+
all_results,
|
595 |
+
n_best_size,
|
596 |
+
max_answer_length,
|
597 |
+
output_prediction_file,
|
598 |
+
output_nbest_file,
|
599 |
+
output_null_log_odds_file,
|
600 |
+
start_n_top,
|
601 |
+
end_n_top,
|
602 |
+
version_2_with_negative,
|
603 |
+
tokenizer,
|
604 |
+
verbose_logging,
|
605 |
+
):
|
606 |
+
"""
|
607 |
+
XLNet write prediction logic (more complex than Bert's). Write final predictions to the json file and log-odds of
|
608 |
+
null if needed.
|
609 |
+
|
610 |
+
Requires utils_squad_evaluate.py
|
611 |
+
"""
|
612 |
+
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
613 |
+
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
|
614 |
+
)
|
615 |
+
|
616 |
+
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
617 |
+
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
|
618 |
+
)
|
619 |
+
|
620 |
+
logger.info(f"Writing predictions to: {output_prediction_file}")
|
621 |
+
|
622 |
+
example_index_to_features = collections.defaultdict(list)
|
623 |
+
for feature in all_features:
|
624 |
+
example_index_to_features[feature.example_index].append(feature)
|
625 |
+
|
626 |
+
unique_id_to_result = {}
|
627 |
+
for result in all_results:
|
628 |
+
unique_id_to_result[result.unique_id] = result
|
629 |
+
|
630 |
+
all_predictions = collections.OrderedDict()
|
631 |
+
all_nbest_json = collections.OrderedDict()
|
632 |
+
scores_diff_json = collections.OrderedDict()
|
633 |
+
|
634 |
+
for example_index, example in enumerate(all_examples):
|
635 |
+
features = example_index_to_features[example_index]
|
636 |
+
|
637 |
+
prelim_predictions = []
|
638 |
+
# keep track of the minimum score of null start+end of position 0
|
639 |
+
score_null = 1000000 # large and positive
|
640 |
+
|
641 |
+
for feature_index, feature in enumerate(features):
|
642 |
+
result = unique_id_to_result[feature.unique_id]
|
643 |
+
|
644 |
+
cur_null_score = result.cls_logits
|
645 |
+
|
646 |
+
# if we could have irrelevant answers, get the min score of irrelevant
|
647 |
+
score_null = min(score_null, cur_null_score)
|
648 |
+
|
649 |
+
for i in range(start_n_top):
|
650 |
+
for j in range(end_n_top):
|
651 |
+
start_log_prob = result.start_logits[i]
|
652 |
+
start_index = result.start_top_index[i]
|
653 |
+
|
654 |
+
j_index = i * end_n_top + j
|
655 |
+
|
656 |
+
end_log_prob = result.end_logits[j_index]
|
657 |
+
end_index = result.end_top_index[j_index]
|
658 |
+
|
659 |
+
# We could hypothetically create invalid predictions, e.g., predict
|
660 |
+
# that the start of the span is in the question. We throw out all
|
661 |
+
# invalid predictions.
|
662 |
+
if start_index >= feature.paragraph_len - 1:
|
663 |
+
continue
|
664 |
+
if end_index >= feature.paragraph_len - 1:
|
665 |
+
continue
|
666 |
+
|
667 |
+
if not feature.token_is_max_context.get(start_index, False):
|
668 |
+
continue
|
669 |
+
if end_index < start_index:
|
670 |
+
continue
|
671 |
+
length = end_index - start_index + 1
|
672 |
+
if length > max_answer_length:
|
673 |
+
continue
|
674 |
+
|
675 |
+
prelim_predictions.append(
|
676 |
+
_PrelimPrediction(
|
677 |
+
feature_index=feature_index,
|
678 |
+
start_index=start_index,
|
679 |
+
end_index=end_index,
|
680 |
+
start_log_prob=start_log_prob,
|
681 |
+
end_log_prob=end_log_prob,
|
682 |
+
)
|
683 |
+
)
|
684 |
+
|
685 |
+
prelim_predictions = sorted(
|
686 |
+
prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
|
687 |
+
)
|
688 |
+
|
689 |
+
seen_predictions = {}
|
690 |
+
nbest = []
|
691 |
+
for pred in prelim_predictions:
|
692 |
+
if len(nbest) >= n_best_size:
|
693 |
+
break
|
694 |
+
feature = features[pred.feature_index]
|
695 |
+
|
696 |
+
# XLNet un-tokenizer
|
697 |
+
# Let's keep it simple for now and see if we need all this later.
|
698 |
+
#
|
699 |
+
# tok_start_to_orig_index = feature.tok_start_to_orig_index
|
700 |
+
# tok_end_to_orig_index = feature.tok_end_to_orig_index
|
701 |
+
# start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
702 |
+
# end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
703 |
+
# paragraph_text = example.paragraph_text
|
704 |
+
# final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
705 |
+
|
706 |
+
# Previously used Bert untokenizer
|
707 |
+
tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
|
708 |
+
orig_doc_start = feature.token_to_orig_map[pred.start_index]
|
709 |
+
orig_doc_end = feature.token_to_orig_map[pred.end_index]
|
710 |
+
orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
|
711 |
+
tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
|
712 |
+
|
713 |
+
# Clean whitespace
|
714 |
+
tok_text = tok_text.strip()
|
715 |
+
tok_text = " ".join(tok_text.split())
|
716 |
+
orig_text = " ".join(orig_tokens)
|
717 |
+
|
718 |
+
if hasattr(tokenizer, "do_lower_case"):
|
719 |
+
do_lower_case = tokenizer.do_lower_case
|
720 |
+
else:
|
721 |
+
do_lower_case = tokenizer.do_lowercase_and_remove_accent
|
722 |
+
|
723 |
+
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
|
724 |
+
|
725 |
+
if final_text in seen_predictions:
|
726 |
+
continue
|
727 |
+
|
728 |
+
seen_predictions[final_text] = True
|
729 |
+
|
730 |
+
nbest.append(
|
731 |
+
_NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
|
732 |
+
)
|
733 |
+
|
734 |
+
# In very rare edge cases we could have no valid predictions. So we
|
735 |
+
# just create a nonce prediction in this case to avoid failure.
|
736 |
+
if not nbest:
|
737 |
+
nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
|
738 |
+
|
739 |
+
total_scores = []
|
740 |
+
best_non_null_entry = None
|
741 |
+
for entry in nbest:
|
742 |
+
total_scores.append(entry.start_log_prob + entry.end_log_prob)
|
743 |
+
if not best_non_null_entry:
|
744 |
+
best_non_null_entry = entry
|
745 |
+
|
746 |
+
probs = _compute_softmax(total_scores)
|
747 |
+
|
748 |
+
nbest_json = []
|
749 |
+
for i, entry in enumerate(nbest):
|
750 |
+
output = collections.OrderedDict()
|
751 |
+
output["text"] = entry.text
|
752 |
+
output["probability"] = probs[i]
|
753 |
+
output["start_log_prob"] = entry.start_log_prob
|
754 |
+
output["end_log_prob"] = entry.end_log_prob
|
755 |
+
nbest_json.append(output)
|
756 |
+
|
757 |
+
if len(nbest_json) < 1:
|
758 |
+
raise ValueError("No valid predictions")
|
759 |
+
if best_non_null_entry is None:
|
760 |
+
raise ValueError("No valid predictions")
|
761 |
+
|
762 |
+
score_diff = score_null
|
763 |
+
scores_diff_json[example.qas_id] = score_diff
|
764 |
+
# note(zhiliny): always predict best_non_null_entry
|
765 |
+
# and the evaluation script will search for the best threshold
|
766 |
+
all_predictions[example.qas_id] = best_non_null_entry.text
|
767 |
+
|
768 |
+
all_nbest_json[example.qas_id] = nbest_json
|
769 |
+
|
770 |
+
with open(output_prediction_file, "w") as writer:
|
771 |
+
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
772 |
+
|
773 |
+
with open(output_nbest_file, "w") as writer:
|
774 |
+
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
775 |
+
|
776 |
+
if version_2_with_negative:
|
777 |
+
with open(output_null_log_odds_file, "w") as writer:
|
778 |
+
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
779 |
+
|
780 |
+
return all_predictions
|
transformers_4_35_0/data/processors/__init__.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .glue import glue_convert_examples_to_features, glue_output_modes, glue_processors, glue_tasks_num_labels
|
16 |
+
from .squad import SquadExample, SquadFeatures, SquadV1Processor, SquadV2Processor, squad_convert_examples_to_features
|
17 |
+
from .utils import DataProcessor, InputExample, InputFeatures, SingleSentenceClassificationProcessor
|
18 |
+
from .xnli import xnli_output_modes, xnli_processors, xnli_tasks_num_labels
|
transformers_4_35_0/data/processors/glue.py
ADDED
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" GLUE processors and helpers"""
|
17 |
+
|
18 |
+
import os
|
19 |
+
import warnings
|
20 |
+
from dataclasses import asdict
|
21 |
+
from enum import Enum
|
22 |
+
from typing import List, Optional, Union
|
23 |
+
|
24 |
+
from ...tokenization_utils import PreTrainedTokenizer
|
25 |
+
from ...utils import is_tf_available, logging
|
26 |
+
from .utils import DataProcessor, InputExample, InputFeatures
|
27 |
+
|
28 |
+
|
29 |
+
if is_tf_available():
|
30 |
+
import tensorflow as tf
|
31 |
+
|
32 |
+
logger = logging.get_logger(__name__)
|
33 |
+
|
34 |
+
DEPRECATION_WARNING = (
|
35 |
+
"This {0} will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
|
36 |
+
"library. You can have a look at this example script for pointers: "
|
37 |
+
"https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py"
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
def glue_convert_examples_to_features(
|
42 |
+
examples: Union[List[InputExample], "tf.data.Dataset"],
|
43 |
+
tokenizer: PreTrainedTokenizer,
|
44 |
+
max_length: Optional[int] = None,
|
45 |
+
task=None,
|
46 |
+
label_list=None,
|
47 |
+
output_mode=None,
|
48 |
+
):
|
49 |
+
"""
|
50 |
+
Loads a data file into a list of `InputFeatures`
|
51 |
+
|
52 |
+
Args:
|
53 |
+
examples: List of `InputExamples` or `tf.data.Dataset` containing the examples.
|
54 |
+
tokenizer: Instance of a tokenizer that will tokenize the examples
|
55 |
+
max_length: Maximum example length. Defaults to the tokenizer's max_len
|
56 |
+
task: GLUE task
|
57 |
+
label_list: List of labels. Can be obtained from the processor using the `processor.get_labels()` method
|
58 |
+
output_mode: String indicating the output mode. Either `regression` or `classification`
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the task-specific
|
62 |
+
features. If the input is a list of `InputExamples`, will return a list of task-specific `InputFeatures` which
|
63 |
+
can be fed to the model.
|
64 |
+
|
65 |
+
"""
|
66 |
+
warnings.warn(DEPRECATION_WARNING.format("function"), FutureWarning)
|
67 |
+
if is_tf_available() and isinstance(examples, tf.data.Dataset):
|
68 |
+
if task is None:
|
69 |
+
raise ValueError("When calling glue_convert_examples_to_features from TF, the task parameter is required.")
|
70 |
+
return _tf_glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
|
71 |
+
return _glue_convert_examples_to_features(
|
72 |
+
examples, tokenizer, max_length=max_length, task=task, label_list=label_list, output_mode=output_mode
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
if is_tf_available():
|
77 |
+
|
78 |
+
def _tf_glue_convert_examples_to_features(
|
79 |
+
examples: tf.data.Dataset,
|
80 |
+
tokenizer: PreTrainedTokenizer,
|
81 |
+
task=str,
|
82 |
+
max_length: Optional[int] = None,
|
83 |
+
) -> tf.data.Dataset:
|
84 |
+
"""
|
85 |
+
Returns:
|
86 |
+
A `tf.data.Dataset` containing the task-specific features.
|
87 |
+
|
88 |
+
"""
|
89 |
+
processor = glue_processors[task]()
|
90 |
+
examples = [processor.tfds_map(processor.get_example_from_tensor_dict(example)) for example in examples]
|
91 |
+
features = glue_convert_examples_to_features(examples, tokenizer, max_length=max_length, task=task)
|
92 |
+
label_type = tf.float32 if task == "sts-b" else tf.int64
|
93 |
+
|
94 |
+
def gen():
|
95 |
+
for ex in features:
|
96 |
+
d = {k: v for k, v in asdict(ex).items() if v is not None}
|
97 |
+
label = d.pop("label")
|
98 |
+
yield (d, label)
|
99 |
+
|
100 |
+
input_names = tokenizer.model_input_names
|
101 |
+
|
102 |
+
return tf.data.Dataset.from_generator(
|
103 |
+
gen,
|
104 |
+
({k: tf.int32 for k in input_names}, label_type),
|
105 |
+
({k: tf.TensorShape([None]) for k in input_names}, tf.TensorShape([])),
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
def _glue_convert_examples_to_features(
|
110 |
+
examples: List[InputExample],
|
111 |
+
tokenizer: PreTrainedTokenizer,
|
112 |
+
max_length: Optional[int] = None,
|
113 |
+
task=None,
|
114 |
+
label_list=None,
|
115 |
+
output_mode=None,
|
116 |
+
):
|
117 |
+
if max_length is None:
|
118 |
+
max_length = tokenizer.model_max_length
|
119 |
+
|
120 |
+
if task is not None:
|
121 |
+
processor = glue_processors[task]()
|
122 |
+
if label_list is None:
|
123 |
+
label_list = processor.get_labels()
|
124 |
+
logger.info(f"Using label list {label_list} for task {task}")
|
125 |
+
if output_mode is None:
|
126 |
+
output_mode = glue_output_modes[task]
|
127 |
+
logger.info(f"Using output mode {output_mode} for task {task}")
|
128 |
+
|
129 |
+
label_map = {label: i for i, label in enumerate(label_list)}
|
130 |
+
|
131 |
+
def label_from_example(example: InputExample) -> Union[int, float, None]:
|
132 |
+
if example.label is None:
|
133 |
+
return None
|
134 |
+
if output_mode == "classification":
|
135 |
+
return label_map[example.label]
|
136 |
+
elif output_mode == "regression":
|
137 |
+
return float(example.label)
|
138 |
+
raise KeyError(output_mode)
|
139 |
+
|
140 |
+
labels = [label_from_example(example) for example in examples]
|
141 |
+
|
142 |
+
batch_encoding = tokenizer(
|
143 |
+
[(example.text_a, example.text_b) for example in examples],
|
144 |
+
max_length=max_length,
|
145 |
+
padding="max_length",
|
146 |
+
truncation=True,
|
147 |
+
)
|
148 |
+
|
149 |
+
features = []
|
150 |
+
for i in range(len(examples)):
|
151 |
+
inputs = {k: batch_encoding[k][i] for k in batch_encoding}
|
152 |
+
|
153 |
+
feature = InputFeatures(**inputs, label=labels[i])
|
154 |
+
features.append(feature)
|
155 |
+
|
156 |
+
for i, example in enumerate(examples[:5]):
|
157 |
+
logger.info("*** Example ***")
|
158 |
+
logger.info(f"guid: {example.guid}")
|
159 |
+
logger.info(f"features: {features[i]}")
|
160 |
+
|
161 |
+
return features
|
162 |
+
|
163 |
+
|
164 |
+
class OutputMode(Enum):
|
165 |
+
classification = "classification"
|
166 |
+
regression = "regression"
|
167 |
+
|
168 |
+
|
169 |
+
class MrpcProcessor(DataProcessor):
|
170 |
+
"""Processor for the MRPC data set (GLUE version)."""
|
171 |
+
|
172 |
+
def __init__(self, *args, **kwargs):
|
173 |
+
super().__init__(*args, **kwargs)
|
174 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
175 |
+
|
176 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
177 |
+
"""See base class."""
|
178 |
+
return InputExample(
|
179 |
+
tensor_dict["idx"].numpy(),
|
180 |
+
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
181 |
+
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
182 |
+
str(tensor_dict["label"].numpy()),
|
183 |
+
)
|
184 |
+
|
185 |
+
def get_train_examples(self, data_dir):
|
186 |
+
"""See base class."""
|
187 |
+
logger.info(f"LOOKING AT {os.path.join(data_dir, 'train.tsv')}")
|
188 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
189 |
+
|
190 |
+
def get_dev_examples(self, data_dir):
|
191 |
+
"""See base class."""
|
192 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
193 |
+
|
194 |
+
def get_test_examples(self, data_dir):
|
195 |
+
"""See base class."""
|
196 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
197 |
+
|
198 |
+
def get_labels(self):
|
199 |
+
"""See base class."""
|
200 |
+
return ["0", "1"]
|
201 |
+
|
202 |
+
def _create_examples(self, lines, set_type):
|
203 |
+
"""Creates examples for the training, dev and test sets."""
|
204 |
+
examples = []
|
205 |
+
for i, line in enumerate(lines):
|
206 |
+
if i == 0:
|
207 |
+
continue
|
208 |
+
guid = f"{set_type}-{i}"
|
209 |
+
text_a = line[3]
|
210 |
+
text_b = line[4]
|
211 |
+
label = None if set_type == "test" else line[0]
|
212 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
213 |
+
return examples
|
214 |
+
|
215 |
+
|
216 |
+
class MnliProcessor(DataProcessor):
|
217 |
+
"""Processor for the MultiNLI data set (GLUE version)."""
|
218 |
+
|
219 |
+
def __init__(self, *args, **kwargs):
|
220 |
+
super().__init__(*args, **kwargs)
|
221 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
222 |
+
|
223 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
224 |
+
"""See base class."""
|
225 |
+
return InputExample(
|
226 |
+
tensor_dict["idx"].numpy(),
|
227 |
+
tensor_dict["premise"].numpy().decode("utf-8"),
|
228 |
+
tensor_dict["hypothesis"].numpy().decode("utf-8"),
|
229 |
+
str(tensor_dict["label"].numpy()),
|
230 |
+
)
|
231 |
+
|
232 |
+
def get_train_examples(self, data_dir):
|
233 |
+
"""See base class."""
|
234 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
235 |
+
|
236 |
+
def get_dev_examples(self, data_dir):
|
237 |
+
"""See base class."""
|
238 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched")
|
239 |
+
|
240 |
+
def get_test_examples(self, data_dir):
|
241 |
+
"""See base class."""
|
242 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test_matched")
|
243 |
+
|
244 |
+
def get_labels(self):
|
245 |
+
"""See base class."""
|
246 |
+
return ["contradiction", "entailment", "neutral"]
|
247 |
+
|
248 |
+
def _create_examples(self, lines, set_type):
|
249 |
+
"""Creates examples for the training, dev and test sets."""
|
250 |
+
examples = []
|
251 |
+
for i, line in enumerate(lines):
|
252 |
+
if i == 0:
|
253 |
+
continue
|
254 |
+
guid = f"{set_type}-{line[0]}"
|
255 |
+
text_a = line[8]
|
256 |
+
text_b = line[9]
|
257 |
+
label = None if set_type.startswith("test") else line[-1]
|
258 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
259 |
+
return examples
|
260 |
+
|
261 |
+
|
262 |
+
class MnliMismatchedProcessor(MnliProcessor):
|
263 |
+
"""Processor for the MultiNLI Mismatched data set (GLUE version)."""
|
264 |
+
|
265 |
+
def __init__(self, *args, **kwargs):
|
266 |
+
super().__init__(*args, **kwargs)
|
267 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
268 |
+
|
269 |
+
def get_dev_examples(self, data_dir):
|
270 |
+
"""See base class."""
|
271 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched")
|
272 |
+
|
273 |
+
def get_test_examples(self, data_dir):
|
274 |
+
"""See base class."""
|
275 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched")
|
276 |
+
|
277 |
+
|
278 |
+
class ColaProcessor(DataProcessor):
|
279 |
+
"""Processor for the CoLA data set (GLUE version)."""
|
280 |
+
|
281 |
+
def __init__(self, *args, **kwargs):
|
282 |
+
super().__init__(*args, **kwargs)
|
283 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
284 |
+
|
285 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
286 |
+
"""See base class."""
|
287 |
+
return InputExample(
|
288 |
+
tensor_dict["idx"].numpy(),
|
289 |
+
tensor_dict["sentence"].numpy().decode("utf-8"),
|
290 |
+
None,
|
291 |
+
str(tensor_dict["label"].numpy()),
|
292 |
+
)
|
293 |
+
|
294 |
+
def get_train_examples(self, data_dir):
|
295 |
+
"""See base class."""
|
296 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
297 |
+
|
298 |
+
def get_dev_examples(self, data_dir):
|
299 |
+
"""See base class."""
|
300 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
301 |
+
|
302 |
+
def get_test_examples(self, data_dir):
|
303 |
+
"""See base class."""
|
304 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
305 |
+
|
306 |
+
def get_labels(self):
|
307 |
+
"""See base class."""
|
308 |
+
return ["0", "1"]
|
309 |
+
|
310 |
+
def _create_examples(self, lines, set_type):
|
311 |
+
"""Creates examples for the training, dev and test sets."""
|
312 |
+
test_mode = set_type == "test"
|
313 |
+
if test_mode:
|
314 |
+
lines = lines[1:]
|
315 |
+
text_index = 1 if test_mode else 3
|
316 |
+
examples = []
|
317 |
+
for i, line in enumerate(lines):
|
318 |
+
guid = f"{set_type}-{i}"
|
319 |
+
text_a = line[text_index]
|
320 |
+
label = None if test_mode else line[1]
|
321 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
322 |
+
return examples
|
323 |
+
|
324 |
+
|
325 |
+
class Sst2Processor(DataProcessor):
|
326 |
+
"""Processor for the SST-2 data set (GLUE version)."""
|
327 |
+
|
328 |
+
def __init__(self, *args, **kwargs):
|
329 |
+
super().__init__(*args, **kwargs)
|
330 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
331 |
+
|
332 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
333 |
+
"""See base class."""
|
334 |
+
return InputExample(
|
335 |
+
tensor_dict["idx"].numpy(),
|
336 |
+
tensor_dict["sentence"].numpy().decode("utf-8"),
|
337 |
+
None,
|
338 |
+
str(tensor_dict["label"].numpy()),
|
339 |
+
)
|
340 |
+
|
341 |
+
def get_train_examples(self, data_dir):
|
342 |
+
"""See base class."""
|
343 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
344 |
+
|
345 |
+
def get_dev_examples(self, data_dir):
|
346 |
+
"""See base class."""
|
347 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
348 |
+
|
349 |
+
def get_test_examples(self, data_dir):
|
350 |
+
"""See base class."""
|
351 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
352 |
+
|
353 |
+
def get_labels(self):
|
354 |
+
"""See base class."""
|
355 |
+
return ["0", "1"]
|
356 |
+
|
357 |
+
def _create_examples(self, lines, set_type):
|
358 |
+
"""Creates examples for the training, dev and test sets."""
|
359 |
+
examples = []
|
360 |
+
text_index = 1 if set_type == "test" else 0
|
361 |
+
for i, line in enumerate(lines):
|
362 |
+
if i == 0:
|
363 |
+
continue
|
364 |
+
guid = f"{set_type}-{i}"
|
365 |
+
text_a = line[text_index]
|
366 |
+
label = None if set_type == "test" else line[1]
|
367 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
368 |
+
return examples
|
369 |
+
|
370 |
+
|
371 |
+
class StsbProcessor(DataProcessor):
|
372 |
+
"""Processor for the STS-B data set (GLUE version)."""
|
373 |
+
|
374 |
+
def __init__(self, *args, **kwargs):
|
375 |
+
super().__init__(*args, **kwargs)
|
376 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
377 |
+
|
378 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
379 |
+
"""See base class."""
|
380 |
+
return InputExample(
|
381 |
+
tensor_dict["idx"].numpy(),
|
382 |
+
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
383 |
+
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
384 |
+
str(tensor_dict["label"].numpy()),
|
385 |
+
)
|
386 |
+
|
387 |
+
def get_train_examples(self, data_dir):
|
388 |
+
"""See base class."""
|
389 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
390 |
+
|
391 |
+
def get_dev_examples(self, data_dir):
|
392 |
+
"""See base class."""
|
393 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
394 |
+
|
395 |
+
def get_test_examples(self, data_dir):
|
396 |
+
"""See base class."""
|
397 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
398 |
+
|
399 |
+
def get_labels(self):
|
400 |
+
"""See base class."""
|
401 |
+
return [None]
|
402 |
+
|
403 |
+
def _create_examples(self, lines, set_type):
|
404 |
+
"""Creates examples for the training, dev and test sets."""
|
405 |
+
examples = []
|
406 |
+
for i, line in enumerate(lines):
|
407 |
+
if i == 0:
|
408 |
+
continue
|
409 |
+
guid = f"{set_type}-{line[0]}"
|
410 |
+
text_a = line[7]
|
411 |
+
text_b = line[8]
|
412 |
+
label = None if set_type == "test" else line[-1]
|
413 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
414 |
+
return examples
|
415 |
+
|
416 |
+
|
417 |
+
class QqpProcessor(DataProcessor):
|
418 |
+
"""Processor for the QQP data set (GLUE version)."""
|
419 |
+
|
420 |
+
def __init__(self, *args, **kwargs):
|
421 |
+
super().__init__(*args, **kwargs)
|
422 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
423 |
+
|
424 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
425 |
+
"""See base class."""
|
426 |
+
return InputExample(
|
427 |
+
tensor_dict["idx"].numpy(),
|
428 |
+
tensor_dict["question1"].numpy().decode("utf-8"),
|
429 |
+
tensor_dict["question2"].numpy().decode("utf-8"),
|
430 |
+
str(tensor_dict["label"].numpy()),
|
431 |
+
)
|
432 |
+
|
433 |
+
def get_train_examples(self, data_dir):
|
434 |
+
"""See base class."""
|
435 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
436 |
+
|
437 |
+
def get_dev_examples(self, data_dir):
|
438 |
+
"""See base class."""
|
439 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
440 |
+
|
441 |
+
def get_test_examples(self, data_dir):
|
442 |
+
"""See base class."""
|
443 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
444 |
+
|
445 |
+
def get_labels(self):
|
446 |
+
"""See base class."""
|
447 |
+
return ["0", "1"]
|
448 |
+
|
449 |
+
def _create_examples(self, lines, set_type):
|
450 |
+
"""Creates examples for the training, dev and test sets."""
|
451 |
+
test_mode = set_type == "test"
|
452 |
+
q1_index = 1 if test_mode else 3
|
453 |
+
q2_index = 2 if test_mode else 4
|
454 |
+
examples = []
|
455 |
+
for i, line in enumerate(lines):
|
456 |
+
if i == 0:
|
457 |
+
continue
|
458 |
+
guid = f"{set_type}-{line[0]}"
|
459 |
+
try:
|
460 |
+
text_a = line[q1_index]
|
461 |
+
text_b = line[q2_index]
|
462 |
+
label = None if test_mode else line[5]
|
463 |
+
except IndexError:
|
464 |
+
continue
|
465 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
466 |
+
return examples
|
467 |
+
|
468 |
+
|
469 |
+
class QnliProcessor(DataProcessor):
|
470 |
+
"""Processor for the QNLI data set (GLUE version)."""
|
471 |
+
|
472 |
+
def __init__(self, *args, **kwargs):
|
473 |
+
super().__init__(*args, **kwargs)
|
474 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
475 |
+
|
476 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
477 |
+
"""See base class."""
|
478 |
+
return InputExample(
|
479 |
+
tensor_dict["idx"].numpy(),
|
480 |
+
tensor_dict["question"].numpy().decode("utf-8"),
|
481 |
+
tensor_dict["sentence"].numpy().decode("utf-8"),
|
482 |
+
str(tensor_dict["label"].numpy()),
|
483 |
+
)
|
484 |
+
|
485 |
+
def get_train_examples(self, data_dir):
|
486 |
+
"""See base class."""
|
487 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
488 |
+
|
489 |
+
def get_dev_examples(self, data_dir):
|
490 |
+
"""See base class."""
|
491 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
492 |
+
|
493 |
+
def get_test_examples(self, data_dir):
|
494 |
+
"""See base class."""
|
495 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
496 |
+
|
497 |
+
def get_labels(self):
|
498 |
+
"""See base class."""
|
499 |
+
return ["entailment", "not_entailment"]
|
500 |
+
|
501 |
+
def _create_examples(self, lines, set_type):
|
502 |
+
"""Creates examples for the training, dev and test sets."""
|
503 |
+
examples = []
|
504 |
+
for i, line in enumerate(lines):
|
505 |
+
if i == 0:
|
506 |
+
continue
|
507 |
+
guid = f"{set_type}-{line[0]}"
|
508 |
+
text_a = line[1]
|
509 |
+
text_b = line[2]
|
510 |
+
label = None if set_type == "test" else line[-1]
|
511 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
512 |
+
return examples
|
513 |
+
|
514 |
+
|
515 |
+
class RteProcessor(DataProcessor):
|
516 |
+
"""Processor for the RTE data set (GLUE version)."""
|
517 |
+
|
518 |
+
def __init__(self, *args, **kwargs):
|
519 |
+
super().__init__(*args, **kwargs)
|
520 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
521 |
+
|
522 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
523 |
+
"""See base class."""
|
524 |
+
return InputExample(
|
525 |
+
tensor_dict["idx"].numpy(),
|
526 |
+
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
527 |
+
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
528 |
+
str(tensor_dict["label"].numpy()),
|
529 |
+
)
|
530 |
+
|
531 |
+
def get_train_examples(self, data_dir):
|
532 |
+
"""See base class."""
|
533 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
534 |
+
|
535 |
+
def get_dev_examples(self, data_dir):
|
536 |
+
"""See base class."""
|
537 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
538 |
+
|
539 |
+
def get_test_examples(self, data_dir):
|
540 |
+
"""See base class."""
|
541 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
542 |
+
|
543 |
+
def get_labels(self):
|
544 |
+
"""See base class."""
|
545 |
+
return ["entailment", "not_entailment"]
|
546 |
+
|
547 |
+
def _create_examples(self, lines, set_type):
|
548 |
+
"""Creates examples for the training, dev and test sets."""
|
549 |
+
examples = []
|
550 |
+
for i, line in enumerate(lines):
|
551 |
+
if i == 0:
|
552 |
+
continue
|
553 |
+
guid = f"{set_type}-{line[0]}"
|
554 |
+
text_a = line[1]
|
555 |
+
text_b = line[2]
|
556 |
+
label = None if set_type == "test" else line[-1]
|
557 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
558 |
+
return examples
|
559 |
+
|
560 |
+
|
561 |
+
class WnliProcessor(DataProcessor):
|
562 |
+
"""Processor for the WNLI data set (GLUE version)."""
|
563 |
+
|
564 |
+
def __init__(self, *args, **kwargs):
|
565 |
+
super().__init__(*args, **kwargs)
|
566 |
+
warnings.warn(DEPRECATION_WARNING.format("processor"), FutureWarning)
|
567 |
+
|
568 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
569 |
+
"""See base class."""
|
570 |
+
return InputExample(
|
571 |
+
tensor_dict["idx"].numpy(),
|
572 |
+
tensor_dict["sentence1"].numpy().decode("utf-8"),
|
573 |
+
tensor_dict["sentence2"].numpy().decode("utf-8"),
|
574 |
+
str(tensor_dict["label"].numpy()),
|
575 |
+
)
|
576 |
+
|
577 |
+
def get_train_examples(self, data_dir):
|
578 |
+
"""See base class."""
|
579 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
580 |
+
|
581 |
+
def get_dev_examples(self, data_dir):
|
582 |
+
"""See base class."""
|
583 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
584 |
+
|
585 |
+
def get_test_examples(self, data_dir):
|
586 |
+
"""See base class."""
|
587 |
+
return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
|
588 |
+
|
589 |
+
def get_labels(self):
|
590 |
+
"""See base class."""
|
591 |
+
return ["0", "1"]
|
592 |
+
|
593 |
+
def _create_examples(self, lines, set_type):
|
594 |
+
"""Creates examples for the training, dev and test sets."""
|
595 |
+
examples = []
|
596 |
+
for i, line in enumerate(lines):
|
597 |
+
if i == 0:
|
598 |
+
continue
|
599 |
+
guid = f"{set_type}-{line[0]}"
|
600 |
+
text_a = line[1]
|
601 |
+
text_b = line[2]
|
602 |
+
label = None if set_type == "test" else line[-1]
|
603 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
604 |
+
return examples
|
605 |
+
|
606 |
+
|
607 |
+
glue_tasks_num_labels = {
|
608 |
+
"cola": 2,
|
609 |
+
"mnli": 3,
|
610 |
+
"mrpc": 2,
|
611 |
+
"sst-2": 2,
|
612 |
+
"sts-b": 1,
|
613 |
+
"qqp": 2,
|
614 |
+
"qnli": 2,
|
615 |
+
"rte": 2,
|
616 |
+
"wnli": 2,
|
617 |
+
}
|
618 |
+
|
619 |
+
glue_processors = {
|
620 |
+
"cola": ColaProcessor,
|
621 |
+
"mnli": MnliProcessor,
|
622 |
+
"mnli-mm": MnliMismatchedProcessor,
|
623 |
+
"mrpc": MrpcProcessor,
|
624 |
+
"sst-2": Sst2Processor,
|
625 |
+
"sts-b": StsbProcessor,
|
626 |
+
"qqp": QqpProcessor,
|
627 |
+
"qnli": QnliProcessor,
|
628 |
+
"rte": RteProcessor,
|
629 |
+
"wnli": WnliProcessor,
|
630 |
+
}
|
631 |
+
|
632 |
+
glue_output_modes = {
|
633 |
+
"cola": "classification",
|
634 |
+
"mnli": "classification",
|
635 |
+
"mnli-mm": "classification",
|
636 |
+
"mrpc": "classification",
|
637 |
+
"sst-2": "classification",
|
638 |
+
"sts-b": "regression",
|
639 |
+
"qqp": "classification",
|
640 |
+
"qnli": "classification",
|
641 |
+
"rte": "classification",
|
642 |
+
"wnli": "classification",
|
643 |
+
}
|
transformers_4_35_0/data/processors/squad.py
ADDED
@@ -0,0 +1,845 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
from functools import partial
|
18 |
+
from multiprocessing import Pool, cpu_count
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
from ...models.bert.tokenization_bert import whitespace_tokenize
|
24 |
+
from ...tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase, TruncationStrategy
|
25 |
+
from ...utils import is_tf_available, is_torch_available, logging
|
26 |
+
from .utils import DataProcessor
|
27 |
+
|
28 |
+
|
29 |
+
# Store the tokenizers which insert 2 separators tokens
|
30 |
+
MULTI_SEP_TOKENS_TOKENIZERS_SET = {"roberta", "camembert", "bart", "mpnet"}
|
31 |
+
|
32 |
+
|
33 |
+
if is_torch_available():
|
34 |
+
import torch
|
35 |
+
from torch.utils.data import TensorDataset
|
36 |
+
|
37 |
+
if is_tf_available():
|
38 |
+
import tensorflow as tf
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__)
|
41 |
+
|
42 |
+
|
43 |
+
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
|
44 |
+
"""Returns tokenized answer spans that better match the annotated answer."""
|
45 |
+
tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
|
46 |
+
|
47 |
+
for new_start in range(input_start, input_end + 1):
|
48 |
+
for new_end in range(input_end, new_start - 1, -1):
|
49 |
+
text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
|
50 |
+
if text_span == tok_answer_text:
|
51 |
+
return (new_start, new_end)
|
52 |
+
|
53 |
+
return (input_start, input_end)
|
54 |
+
|
55 |
+
|
56 |
+
def _check_is_max_context(doc_spans, cur_span_index, position):
|
57 |
+
"""Check if this is the 'max context' doc span for the token."""
|
58 |
+
best_score = None
|
59 |
+
best_span_index = None
|
60 |
+
for span_index, doc_span in enumerate(doc_spans):
|
61 |
+
end = doc_span.start + doc_span.length - 1
|
62 |
+
if position < doc_span.start:
|
63 |
+
continue
|
64 |
+
if position > end:
|
65 |
+
continue
|
66 |
+
num_left_context = position - doc_span.start
|
67 |
+
num_right_context = end - position
|
68 |
+
score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
|
69 |
+
if best_score is None or score > best_score:
|
70 |
+
best_score = score
|
71 |
+
best_span_index = span_index
|
72 |
+
|
73 |
+
return cur_span_index == best_span_index
|
74 |
+
|
75 |
+
|
76 |
+
def _new_check_is_max_context(doc_spans, cur_span_index, position):
|
77 |
+
"""Check if this is the 'max context' doc span for the token."""
|
78 |
+
# if len(doc_spans) == 1:
|
79 |
+
# return True
|
80 |
+
best_score = None
|
81 |
+
best_span_index = None
|
82 |
+
for span_index, doc_span in enumerate(doc_spans):
|
83 |
+
end = doc_span["start"] + doc_span["length"] - 1
|
84 |
+
if position < doc_span["start"]:
|
85 |
+
continue
|
86 |
+
if position > end:
|
87 |
+
continue
|
88 |
+
num_left_context = position - doc_span["start"]
|
89 |
+
num_right_context = end - position
|
90 |
+
score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"]
|
91 |
+
if best_score is None or score > best_score:
|
92 |
+
best_score = score
|
93 |
+
best_span_index = span_index
|
94 |
+
|
95 |
+
return cur_span_index == best_span_index
|
96 |
+
|
97 |
+
|
98 |
+
def _is_whitespace(c):
|
99 |
+
if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
|
100 |
+
return True
|
101 |
+
return False
|
102 |
+
|
103 |
+
|
104 |
+
def squad_convert_example_to_features(
|
105 |
+
example, max_seq_length, doc_stride, max_query_length, padding_strategy, is_training
|
106 |
+
):
|
107 |
+
features = []
|
108 |
+
if is_training and not example.is_impossible:
|
109 |
+
# Get start and end position
|
110 |
+
start_position = example.start_position
|
111 |
+
end_position = example.end_position
|
112 |
+
|
113 |
+
# If the answer cannot be found in the text, then skip this example.
|
114 |
+
actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
|
115 |
+
cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
|
116 |
+
if actual_text.find(cleaned_answer_text) == -1:
|
117 |
+
logger.warning(f"Could not find answer: '{actual_text}' vs. '{cleaned_answer_text}'")
|
118 |
+
return []
|
119 |
+
|
120 |
+
tok_to_orig_index = []
|
121 |
+
orig_to_tok_index = []
|
122 |
+
all_doc_tokens = []
|
123 |
+
for i, token in enumerate(example.doc_tokens):
|
124 |
+
orig_to_tok_index.append(len(all_doc_tokens))
|
125 |
+
if tokenizer.__class__.__name__ in [
|
126 |
+
"RobertaTokenizer",
|
127 |
+
"LongformerTokenizer",
|
128 |
+
"BartTokenizer",
|
129 |
+
"RobertaTokenizerFast",
|
130 |
+
"LongformerTokenizerFast",
|
131 |
+
"BartTokenizerFast",
|
132 |
+
]:
|
133 |
+
sub_tokens = tokenizer.tokenize(token, add_prefix_space=True)
|
134 |
+
else:
|
135 |
+
sub_tokens = tokenizer.tokenize(token)
|
136 |
+
for sub_token in sub_tokens:
|
137 |
+
tok_to_orig_index.append(i)
|
138 |
+
all_doc_tokens.append(sub_token)
|
139 |
+
|
140 |
+
if is_training and not example.is_impossible:
|
141 |
+
tok_start_position = orig_to_tok_index[example.start_position]
|
142 |
+
if example.end_position < len(example.doc_tokens) - 1:
|
143 |
+
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
|
144 |
+
else:
|
145 |
+
tok_end_position = len(all_doc_tokens) - 1
|
146 |
+
|
147 |
+
(tok_start_position, tok_end_position) = _improve_answer_span(
|
148 |
+
all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
|
149 |
+
)
|
150 |
+
|
151 |
+
spans = []
|
152 |
+
|
153 |
+
truncated_query = tokenizer.encode(
|
154 |
+
example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length
|
155 |
+
)
|
156 |
+
|
157 |
+
# Tokenizers who insert 2 SEP tokens in-between <context> & <question> need to have special handling
|
158 |
+
# in the way they compute mask of added tokens.
|
159 |
+
tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower()
|
160 |
+
sequence_added_tokens = (
|
161 |
+
tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1
|
162 |
+
if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET
|
163 |
+
else tokenizer.model_max_length - tokenizer.max_len_single_sentence
|
164 |
+
)
|
165 |
+
sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair
|
166 |
+
|
167 |
+
span_doc_tokens = all_doc_tokens
|
168 |
+
while len(spans) * doc_stride < len(all_doc_tokens):
|
169 |
+
# Define the side we want to truncate / pad and the text/pair sorting
|
170 |
+
if tokenizer.padding_side == "right":
|
171 |
+
texts = truncated_query
|
172 |
+
pairs = span_doc_tokens
|
173 |
+
truncation = TruncationStrategy.ONLY_SECOND.value
|
174 |
+
else:
|
175 |
+
texts = span_doc_tokens
|
176 |
+
pairs = truncated_query
|
177 |
+
truncation = TruncationStrategy.ONLY_FIRST.value
|
178 |
+
|
179 |
+
encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic
|
180 |
+
texts,
|
181 |
+
pairs,
|
182 |
+
truncation=truncation,
|
183 |
+
padding=padding_strategy,
|
184 |
+
max_length=max_seq_length,
|
185 |
+
return_overflowing_tokens=True,
|
186 |
+
stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
|
187 |
+
return_token_type_ids=True,
|
188 |
+
)
|
189 |
+
|
190 |
+
paragraph_len = min(
|
191 |
+
len(all_doc_tokens) - len(spans) * doc_stride,
|
192 |
+
max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
|
193 |
+
)
|
194 |
+
|
195 |
+
if tokenizer.pad_token_id in encoded_dict["input_ids"]:
|
196 |
+
if tokenizer.padding_side == "right":
|
197 |
+
non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
|
198 |
+
else:
|
199 |
+
last_padding_id_position = (
|
200 |
+
len(encoded_dict["input_ids"]) - 1 - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id)
|
201 |
+
)
|
202 |
+
non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :]
|
203 |
+
|
204 |
+
else:
|
205 |
+
non_padded_ids = encoded_dict["input_ids"]
|
206 |
+
|
207 |
+
tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
|
208 |
+
|
209 |
+
token_to_orig_map = {}
|
210 |
+
for i in range(paragraph_len):
|
211 |
+
index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
|
212 |
+
token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]
|
213 |
+
|
214 |
+
encoded_dict["paragraph_len"] = paragraph_len
|
215 |
+
encoded_dict["tokens"] = tokens
|
216 |
+
encoded_dict["token_to_orig_map"] = token_to_orig_map
|
217 |
+
encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens
|
218 |
+
encoded_dict["token_is_max_context"] = {}
|
219 |
+
encoded_dict["start"] = len(spans) * doc_stride
|
220 |
+
encoded_dict["length"] = paragraph_len
|
221 |
+
|
222 |
+
spans.append(encoded_dict)
|
223 |
+
|
224 |
+
if "overflowing_tokens" not in encoded_dict or (
|
225 |
+
"overflowing_tokens" in encoded_dict and len(encoded_dict["overflowing_tokens"]) == 0
|
226 |
+
):
|
227 |
+
break
|
228 |
+
span_doc_tokens = encoded_dict["overflowing_tokens"]
|
229 |
+
|
230 |
+
for doc_span_index in range(len(spans)):
|
231 |
+
for j in range(spans[doc_span_index]["paragraph_len"]):
|
232 |
+
is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
|
233 |
+
index = (
|
234 |
+
j
|
235 |
+
if tokenizer.padding_side == "left"
|
236 |
+
else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
|
237 |
+
)
|
238 |
+
spans[doc_span_index]["token_is_max_context"][index] = is_max_context
|
239 |
+
|
240 |
+
for span in spans:
|
241 |
+
# Identify the position of the CLS token
|
242 |
+
cls_index = span["input_ids"].index(tokenizer.cls_token_id)
|
243 |
+
|
244 |
+
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
|
245 |
+
# Original TF implementation also keep the classification token (set to 0)
|
246 |
+
p_mask = np.ones_like(span["token_type_ids"])
|
247 |
+
if tokenizer.padding_side == "right":
|
248 |
+
p_mask[len(truncated_query) + sequence_added_tokens :] = 0
|
249 |
+
else:
|
250 |
+
p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
|
251 |
+
|
252 |
+
pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
|
253 |
+
special_token_indices = np.asarray(
|
254 |
+
tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True)
|
255 |
+
).nonzero()
|
256 |
+
|
257 |
+
p_mask[pad_token_indices] = 1
|
258 |
+
p_mask[special_token_indices] = 1
|
259 |
+
|
260 |
+
# Set the cls index to 0: the CLS index can be used for impossible answers
|
261 |
+
p_mask[cls_index] = 0
|
262 |
+
|
263 |
+
span_is_impossible = example.is_impossible
|
264 |
+
start_position = 0
|
265 |
+
end_position = 0
|
266 |
+
if is_training and not span_is_impossible:
|
267 |
+
# For training, if our document chunk does not contain an annotation
|
268 |
+
# we throw it out, since there is nothing to predict.
|
269 |
+
doc_start = span["start"]
|
270 |
+
doc_end = span["start"] + span["length"] - 1
|
271 |
+
out_of_span = False
|
272 |
+
|
273 |
+
if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
|
274 |
+
out_of_span = True
|
275 |
+
|
276 |
+
if out_of_span:
|
277 |
+
start_position = cls_index
|
278 |
+
end_position = cls_index
|
279 |
+
span_is_impossible = True
|
280 |
+
else:
|
281 |
+
if tokenizer.padding_side == "left":
|
282 |
+
doc_offset = 0
|
283 |
+
else:
|
284 |
+
doc_offset = len(truncated_query) + sequence_added_tokens
|
285 |
+
|
286 |
+
start_position = tok_start_position - doc_start + doc_offset
|
287 |
+
end_position = tok_end_position - doc_start + doc_offset
|
288 |
+
|
289 |
+
features.append(
|
290 |
+
SquadFeatures(
|
291 |
+
span["input_ids"],
|
292 |
+
span["attention_mask"],
|
293 |
+
span["token_type_ids"],
|
294 |
+
cls_index,
|
295 |
+
p_mask.tolist(),
|
296 |
+
example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing.
|
297 |
+
unique_id=0,
|
298 |
+
paragraph_len=span["paragraph_len"],
|
299 |
+
token_is_max_context=span["token_is_max_context"],
|
300 |
+
tokens=span["tokens"],
|
301 |
+
token_to_orig_map=span["token_to_orig_map"],
|
302 |
+
start_position=start_position,
|
303 |
+
end_position=end_position,
|
304 |
+
is_impossible=span_is_impossible,
|
305 |
+
qas_id=example.qas_id,
|
306 |
+
)
|
307 |
+
)
|
308 |
+
return features
|
309 |
+
|
310 |
+
|
311 |
+
def squad_convert_example_to_features_init(tokenizer_for_convert: PreTrainedTokenizerBase):
|
312 |
+
global tokenizer
|
313 |
+
tokenizer = tokenizer_for_convert
|
314 |
+
|
315 |
+
|
316 |
+
def squad_convert_examples_to_features(
|
317 |
+
examples,
|
318 |
+
tokenizer,
|
319 |
+
max_seq_length,
|
320 |
+
doc_stride,
|
321 |
+
max_query_length,
|
322 |
+
is_training,
|
323 |
+
padding_strategy="max_length",
|
324 |
+
return_dataset=False,
|
325 |
+
threads=1,
|
326 |
+
tqdm_enabled=True,
|
327 |
+
):
|
328 |
+
"""
|
329 |
+
Converts a list of examples into a list of features that can be directly given as input to a model. It is
|
330 |
+
model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
examples: list of [`~data.processors.squad.SquadExample`]
|
334 |
+
tokenizer: an instance of a child of [`PreTrainedTokenizer`]
|
335 |
+
max_seq_length: The maximum sequence length of the inputs.
|
336 |
+
doc_stride: The stride used when the context is too large and is split across several features.
|
337 |
+
max_query_length: The maximum length of the query.
|
338 |
+
is_training: whether to create features for model evaluation or model training.
|
339 |
+
padding_strategy: Default to "max_length". Which padding strategy to use
|
340 |
+
return_dataset: Default False. Either 'pt' or 'tf'.
|
341 |
+
if 'pt': returns a torch.data.TensorDataset, if 'tf': returns a tf.data.Dataset
|
342 |
+
threads: multiple processing threads.
|
343 |
+
|
344 |
+
|
345 |
+
Returns:
|
346 |
+
list of [`~data.processors.squad.SquadFeatures`]
|
347 |
+
|
348 |
+
Example:
|
349 |
+
|
350 |
+
```python
|
351 |
+
processor = SquadV2Processor()
|
352 |
+
examples = processor.get_dev_examples(data_dir)
|
353 |
+
|
354 |
+
features = squad_convert_examples_to_features(
|
355 |
+
examples=examples,
|
356 |
+
tokenizer=tokenizer,
|
357 |
+
max_seq_length=args.max_seq_length,
|
358 |
+
doc_stride=args.doc_stride,
|
359 |
+
max_query_length=args.max_query_length,
|
360 |
+
is_training=not evaluate,
|
361 |
+
)
|
362 |
+
```"""
|
363 |
+
# Defining helper methods
|
364 |
+
features = []
|
365 |
+
|
366 |
+
threads = min(threads, cpu_count())
|
367 |
+
with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
|
368 |
+
annotate_ = partial(
|
369 |
+
squad_convert_example_to_features,
|
370 |
+
max_seq_length=max_seq_length,
|
371 |
+
doc_stride=doc_stride,
|
372 |
+
max_query_length=max_query_length,
|
373 |
+
padding_strategy=padding_strategy,
|
374 |
+
is_training=is_training,
|
375 |
+
)
|
376 |
+
features = list(
|
377 |
+
tqdm(
|
378 |
+
p.imap(annotate_, examples, chunksize=32),
|
379 |
+
total=len(examples),
|
380 |
+
desc="convert squad examples to features",
|
381 |
+
disable=not tqdm_enabled,
|
382 |
+
)
|
383 |
+
)
|
384 |
+
|
385 |
+
new_features = []
|
386 |
+
unique_id = 1000000000
|
387 |
+
example_index = 0
|
388 |
+
for example_features in tqdm(
|
389 |
+
features, total=len(features), desc="add example index and unique id", disable=not tqdm_enabled
|
390 |
+
):
|
391 |
+
if not example_features:
|
392 |
+
continue
|
393 |
+
for example_feature in example_features:
|
394 |
+
example_feature.example_index = example_index
|
395 |
+
example_feature.unique_id = unique_id
|
396 |
+
new_features.append(example_feature)
|
397 |
+
unique_id += 1
|
398 |
+
example_index += 1
|
399 |
+
features = new_features
|
400 |
+
del new_features
|
401 |
+
if return_dataset == "pt":
|
402 |
+
if not is_torch_available():
|
403 |
+
raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")
|
404 |
+
|
405 |
+
# Convert to Tensors and build dataset
|
406 |
+
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
407 |
+
all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
408 |
+
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
409 |
+
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
|
410 |
+
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
411 |
+
all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)
|
412 |
+
|
413 |
+
if not is_training:
|
414 |
+
all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
415 |
+
dataset = TensorDataset(
|
416 |
+
all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask
|
417 |
+
)
|
418 |
+
else:
|
419 |
+
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
420 |
+
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
421 |
+
dataset = TensorDataset(
|
422 |
+
all_input_ids,
|
423 |
+
all_attention_masks,
|
424 |
+
all_token_type_ids,
|
425 |
+
all_start_positions,
|
426 |
+
all_end_positions,
|
427 |
+
all_cls_index,
|
428 |
+
all_p_mask,
|
429 |
+
all_is_impossible,
|
430 |
+
)
|
431 |
+
|
432 |
+
return features, dataset
|
433 |
+
elif return_dataset == "tf":
|
434 |
+
if not is_tf_available():
|
435 |
+
raise RuntimeError("TensorFlow must be installed to return a TensorFlow dataset.")
|
436 |
+
|
437 |
+
def gen():
|
438 |
+
for i, ex in enumerate(features):
|
439 |
+
if ex.token_type_ids is None:
|
440 |
+
yield (
|
441 |
+
{
|
442 |
+
"input_ids": ex.input_ids,
|
443 |
+
"attention_mask": ex.attention_mask,
|
444 |
+
"feature_index": i,
|
445 |
+
"qas_id": ex.qas_id,
|
446 |
+
},
|
447 |
+
{
|
448 |
+
"start_positions": ex.start_position,
|
449 |
+
"end_positions": ex.end_position,
|
450 |
+
"cls_index": ex.cls_index,
|
451 |
+
"p_mask": ex.p_mask,
|
452 |
+
"is_impossible": ex.is_impossible,
|
453 |
+
},
|
454 |
+
)
|
455 |
+
else:
|
456 |
+
yield (
|
457 |
+
{
|
458 |
+
"input_ids": ex.input_ids,
|
459 |
+
"attention_mask": ex.attention_mask,
|
460 |
+
"token_type_ids": ex.token_type_ids,
|
461 |
+
"feature_index": i,
|
462 |
+
"qas_id": ex.qas_id,
|
463 |
+
},
|
464 |
+
{
|
465 |
+
"start_positions": ex.start_position,
|
466 |
+
"end_positions": ex.end_position,
|
467 |
+
"cls_index": ex.cls_index,
|
468 |
+
"p_mask": ex.p_mask,
|
469 |
+
"is_impossible": ex.is_impossible,
|
470 |
+
},
|
471 |
+
)
|
472 |
+
|
473 |
+
# Why have we split the batch into a tuple? PyTorch just has a list of tensors.
|
474 |
+
if "token_type_ids" in tokenizer.model_input_names:
|
475 |
+
train_types = (
|
476 |
+
{
|
477 |
+
"input_ids": tf.int32,
|
478 |
+
"attention_mask": tf.int32,
|
479 |
+
"token_type_ids": tf.int32,
|
480 |
+
"feature_index": tf.int64,
|
481 |
+
"qas_id": tf.string,
|
482 |
+
},
|
483 |
+
{
|
484 |
+
"start_positions": tf.int64,
|
485 |
+
"end_positions": tf.int64,
|
486 |
+
"cls_index": tf.int64,
|
487 |
+
"p_mask": tf.int32,
|
488 |
+
"is_impossible": tf.int32,
|
489 |
+
},
|
490 |
+
)
|
491 |
+
|
492 |
+
train_shapes = (
|
493 |
+
{
|
494 |
+
"input_ids": tf.TensorShape([None]),
|
495 |
+
"attention_mask": tf.TensorShape([None]),
|
496 |
+
"token_type_ids": tf.TensorShape([None]),
|
497 |
+
"feature_index": tf.TensorShape([]),
|
498 |
+
"qas_id": tf.TensorShape([]),
|
499 |
+
},
|
500 |
+
{
|
501 |
+
"start_positions": tf.TensorShape([]),
|
502 |
+
"end_positions": tf.TensorShape([]),
|
503 |
+
"cls_index": tf.TensorShape([]),
|
504 |
+
"p_mask": tf.TensorShape([None]),
|
505 |
+
"is_impossible": tf.TensorShape([]),
|
506 |
+
},
|
507 |
+
)
|
508 |
+
else:
|
509 |
+
train_types = (
|
510 |
+
{"input_ids": tf.int32, "attention_mask": tf.int32, "feature_index": tf.int64, "qas_id": tf.string},
|
511 |
+
{
|
512 |
+
"start_positions": tf.int64,
|
513 |
+
"end_positions": tf.int64,
|
514 |
+
"cls_index": tf.int64,
|
515 |
+
"p_mask": tf.int32,
|
516 |
+
"is_impossible": tf.int32,
|
517 |
+
},
|
518 |
+
)
|
519 |
+
|
520 |
+
train_shapes = (
|
521 |
+
{
|
522 |
+
"input_ids": tf.TensorShape([None]),
|
523 |
+
"attention_mask": tf.TensorShape([None]),
|
524 |
+
"feature_index": tf.TensorShape([]),
|
525 |
+
"qas_id": tf.TensorShape([]),
|
526 |
+
},
|
527 |
+
{
|
528 |
+
"start_positions": tf.TensorShape([]),
|
529 |
+
"end_positions": tf.TensorShape([]),
|
530 |
+
"cls_index": tf.TensorShape([]),
|
531 |
+
"p_mask": tf.TensorShape([None]),
|
532 |
+
"is_impossible": tf.TensorShape([]),
|
533 |
+
},
|
534 |
+
)
|
535 |
+
|
536 |
+
return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
|
537 |
+
else:
|
538 |
+
return features
|
539 |
+
|
540 |
+
|
541 |
+
class SquadProcessor(DataProcessor):
|
542 |
+
"""
|
543 |
+
Processor for the SQuAD data set. overridden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and
|
544 |
+
version 2.0 of SQuAD, respectively.
|
545 |
+
"""
|
546 |
+
|
547 |
+
train_file = None
|
548 |
+
dev_file = None
|
549 |
+
|
550 |
+
def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
|
551 |
+
if not evaluate:
|
552 |
+
answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
|
553 |
+
answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
|
554 |
+
answers = []
|
555 |
+
else:
|
556 |
+
answers = [
|
557 |
+
{"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
|
558 |
+
for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
|
559 |
+
]
|
560 |
+
|
561 |
+
answer = None
|
562 |
+
answer_start = None
|
563 |
+
|
564 |
+
return SquadExample(
|
565 |
+
qas_id=tensor_dict["id"].numpy().decode("utf-8"),
|
566 |
+
question_text=tensor_dict["question"].numpy().decode("utf-8"),
|
567 |
+
context_text=tensor_dict["context"].numpy().decode("utf-8"),
|
568 |
+
answer_text=answer,
|
569 |
+
start_position_character=answer_start,
|
570 |
+
title=tensor_dict["title"].numpy().decode("utf-8"),
|
571 |
+
answers=answers,
|
572 |
+
)
|
573 |
+
|
574 |
+
def get_examples_from_dataset(self, dataset, evaluate=False):
|
575 |
+
"""
|
576 |
+
Creates a list of [`~data.processors.squad.SquadExample`] using a TFDS dataset.
|
577 |
+
|
578 |
+
Args:
|
579 |
+
dataset: The tfds dataset loaded from *tensorflow_datasets.load("squad")*
|
580 |
+
evaluate: Boolean specifying if in evaluation mode or in training mode
|
581 |
+
|
582 |
+
Returns:
|
583 |
+
List of SquadExample
|
584 |
+
|
585 |
+
Examples:
|
586 |
+
|
587 |
+
```python
|
588 |
+
>>> import tensorflow_datasets as tfds
|
589 |
+
|
590 |
+
>>> dataset = tfds.load("squad")
|
591 |
+
|
592 |
+
>>> training_examples = get_examples_from_dataset(dataset, evaluate=False)
|
593 |
+
>>> evaluation_examples = get_examples_from_dataset(dataset, evaluate=True)
|
594 |
+
```"""
|
595 |
+
|
596 |
+
if evaluate:
|
597 |
+
dataset = dataset["validation"]
|
598 |
+
else:
|
599 |
+
dataset = dataset["train"]
|
600 |
+
|
601 |
+
examples = []
|
602 |
+
for tensor_dict in tqdm(dataset):
|
603 |
+
examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))
|
604 |
+
|
605 |
+
return examples
|
606 |
+
|
607 |
+
def get_train_examples(self, data_dir, filename=None):
|
608 |
+
"""
|
609 |
+
Returns the training examples from the data directory.
|
610 |
+
|
611 |
+
Args:
|
612 |
+
data_dir: Directory containing the data files used for training and evaluating.
|
613 |
+
filename: None by default, specify this if the training file has a different name than the original one
|
614 |
+
which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively.
|
615 |
+
|
616 |
+
"""
|
617 |
+
if data_dir is None:
|
618 |
+
data_dir = ""
|
619 |
+
|
620 |
+
if self.train_file is None:
|
621 |
+
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
|
622 |
+
|
623 |
+
with open(
|
624 |
+
os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
|
625 |
+
) as reader:
|
626 |
+
input_data = json.load(reader)["data"]
|
627 |
+
return self._create_examples(input_data, "train")
|
628 |
+
|
629 |
+
def get_dev_examples(self, data_dir, filename=None):
|
630 |
+
"""
|
631 |
+
Returns the evaluation example from the data directory.
|
632 |
+
|
633 |
+
Args:
|
634 |
+
data_dir: Directory containing the data files used for training and evaluating.
|
635 |
+
filename: None by default, specify this if the evaluation file has a different name than the original one
|
636 |
+
which is `dev-v1.1.json` and `dev-v2.0.json` for squad versions 1.1 and 2.0 respectively.
|
637 |
+
"""
|
638 |
+
if data_dir is None:
|
639 |
+
data_dir = ""
|
640 |
+
|
641 |
+
if self.dev_file is None:
|
642 |
+
raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")
|
643 |
+
|
644 |
+
with open(
|
645 |
+
os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
|
646 |
+
) as reader:
|
647 |
+
input_data = json.load(reader)["data"]
|
648 |
+
return self._create_examples(input_data, "dev")
|
649 |
+
|
650 |
+
def _create_examples(self, input_data, set_type):
|
651 |
+
is_training = set_type == "train"
|
652 |
+
examples = []
|
653 |
+
for entry in tqdm(input_data):
|
654 |
+
title = entry["title"]
|
655 |
+
for paragraph in entry["paragraphs"]:
|
656 |
+
context_text = paragraph["context"]
|
657 |
+
for qa in paragraph["qas"]:
|
658 |
+
qas_id = qa["id"]
|
659 |
+
question_text = qa["question"]
|
660 |
+
start_position_character = None
|
661 |
+
answer_text = None
|
662 |
+
answers = []
|
663 |
+
|
664 |
+
is_impossible = qa.get("is_impossible", False)
|
665 |
+
if not is_impossible:
|
666 |
+
if is_training:
|
667 |
+
answer = qa["answers"][0]
|
668 |
+
answer_text = answer["text"]
|
669 |
+
start_position_character = answer["answer_start"]
|
670 |
+
else:
|
671 |
+
answers = qa["answers"]
|
672 |
+
|
673 |
+
example = SquadExample(
|
674 |
+
qas_id=qas_id,
|
675 |
+
question_text=question_text,
|
676 |
+
context_text=context_text,
|
677 |
+
answer_text=answer_text,
|
678 |
+
start_position_character=start_position_character,
|
679 |
+
title=title,
|
680 |
+
is_impossible=is_impossible,
|
681 |
+
answers=answers,
|
682 |
+
)
|
683 |
+
examples.append(example)
|
684 |
+
return examples
|
685 |
+
|
686 |
+
|
687 |
+
class SquadV1Processor(SquadProcessor):
|
688 |
+
train_file = "train-v1.1.json"
|
689 |
+
dev_file = "dev-v1.1.json"
|
690 |
+
|
691 |
+
|
692 |
+
class SquadV2Processor(SquadProcessor):
|
693 |
+
train_file = "train-v2.0.json"
|
694 |
+
dev_file = "dev-v2.0.json"
|
695 |
+
|
696 |
+
|
697 |
+
class SquadExample:
|
698 |
+
"""
|
699 |
+
A single training/test example for the Squad dataset, as loaded from disk.
|
700 |
+
|
701 |
+
Args:
|
702 |
+
qas_id: The example's unique identifier
|
703 |
+
question_text: The question string
|
704 |
+
context_text: The context string
|
705 |
+
answer_text: The answer string
|
706 |
+
start_position_character: The character position of the start of the answer
|
707 |
+
title: The title of the example
|
708 |
+
answers: None by default, this is used during evaluation. Holds answers as well as their start positions.
|
709 |
+
is_impossible: False by default, set to True if the example has no possible answer.
|
710 |
+
"""
|
711 |
+
|
712 |
+
def __init__(
|
713 |
+
self,
|
714 |
+
qas_id,
|
715 |
+
question_text,
|
716 |
+
context_text,
|
717 |
+
answer_text,
|
718 |
+
start_position_character,
|
719 |
+
title,
|
720 |
+
answers=[],
|
721 |
+
is_impossible=False,
|
722 |
+
):
|
723 |
+
self.qas_id = qas_id
|
724 |
+
self.question_text = question_text
|
725 |
+
self.context_text = context_text
|
726 |
+
self.answer_text = answer_text
|
727 |
+
self.title = title
|
728 |
+
self.is_impossible = is_impossible
|
729 |
+
self.answers = answers
|
730 |
+
|
731 |
+
self.start_position, self.end_position = 0, 0
|
732 |
+
|
733 |
+
doc_tokens = []
|
734 |
+
char_to_word_offset = []
|
735 |
+
prev_is_whitespace = True
|
736 |
+
|
737 |
+
# Split on whitespace so that different tokens may be attributed to their original position.
|
738 |
+
for c in self.context_text:
|
739 |
+
if _is_whitespace(c):
|
740 |
+
prev_is_whitespace = True
|
741 |
+
else:
|
742 |
+
if prev_is_whitespace:
|
743 |
+
doc_tokens.append(c)
|
744 |
+
else:
|
745 |
+
doc_tokens[-1] += c
|
746 |
+
prev_is_whitespace = False
|
747 |
+
char_to_word_offset.append(len(doc_tokens) - 1)
|
748 |
+
|
749 |
+
self.doc_tokens = doc_tokens
|
750 |
+
self.char_to_word_offset = char_to_word_offset
|
751 |
+
|
752 |
+
# Start and end positions only has a value during evaluation.
|
753 |
+
if start_position_character is not None and not is_impossible:
|
754 |
+
self.start_position = char_to_word_offset[start_position_character]
|
755 |
+
self.end_position = char_to_word_offset[
|
756 |
+
min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1)
|
757 |
+
]
|
758 |
+
|
759 |
+
|
760 |
+
class SquadFeatures:
|
761 |
+
"""
|
762 |
+
Single squad example features to be fed to a model. Those features are model-specific and can be crafted from
|
763 |
+
[`~data.processors.squad.SquadExample`] using the
|
764 |
+
:method:*~transformers.data.processors.squad.squad_convert_examples_to_features* method.
|
765 |
+
|
766 |
+
Args:
|
767 |
+
input_ids: Indices of input sequence tokens in the vocabulary.
|
768 |
+
attention_mask: Mask to avoid performing attention on padding token indices.
|
769 |
+
token_type_ids: Segment token indices to indicate first and second portions of the inputs.
|
770 |
+
cls_index: the index of the CLS token.
|
771 |
+
p_mask: Mask identifying tokens that can be answers vs. tokens that cannot.
|
772 |
+
Mask with 1 for tokens than cannot be in the answer and 0 for token that can be in an answer
|
773 |
+
example_index: the index of the example
|
774 |
+
unique_id: The unique Feature identifier
|
775 |
+
paragraph_len: The length of the context
|
776 |
+
token_is_max_context:
|
777 |
+
List of booleans identifying which tokens have their maximum context in this feature object. If a token
|
778 |
+
does not have their maximum context in this feature object, it means that another feature object has more
|
779 |
+
information related to that token and should be prioritized over this feature for that token.
|
780 |
+
tokens: list of tokens corresponding to the input ids
|
781 |
+
token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer.
|
782 |
+
start_position: start of the answer token index
|
783 |
+
end_position: end of the answer token index
|
784 |
+
encoding: optionally store the BatchEncoding with the fast-tokenizer alignment methods.
|
785 |
+
"""
|
786 |
+
|
787 |
+
def __init__(
|
788 |
+
self,
|
789 |
+
input_ids,
|
790 |
+
attention_mask,
|
791 |
+
token_type_ids,
|
792 |
+
cls_index,
|
793 |
+
p_mask,
|
794 |
+
example_index,
|
795 |
+
unique_id,
|
796 |
+
paragraph_len,
|
797 |
+
token_is_max_context,
|
798 |
+
tokens,
|
799 |
+
token_to_orig_map,
|
800 |
+
start_position,
|
801 |
+
end_position,
|
802 |
+
is_impossible,
|
803 |
+
qas_id: str = None,
|
804 |
+
encoding: BatchEncoding = None,
|
805 |
+
):
|
806 |
+
self.input_ids = input_ids
|
807 |
+
self.attention_mask = attention_mask
|
808 |
+
self.token_type_ids = token_type_ids
|
809 |
+
self.cls_index = cls_index
|
810 |
+
self.p_mask = p_mask
|
811 |
+
|
812 |
+
self.example_index = example_index
|
813 |
+
self.unique_id = unique_id
|
814 |
+
self.paragraph_len = paragraph_len
|
815 |
+
self.token_is_max_context = token_is_max_context
|
816 |
+
self.tokens = tokens
|
817 |
+
self.token_to_orig_map = token_to_orig_map
|
818 |
+
|
819 |
+
self.start_position = start_position
|
820 |
+
self.end_position = end_position
|
821 |
+
self.is_impossible = is_impossible
|
822 |
+
self.qas_id = qas_id
|
823 |
+
|
824 |
+
self.encoding = encoding
|
825 |
+
|
826 |
+
|
827 |
+
class SquadResult:
|
828 |
+
"""
|
829 |
+
Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
|
830 |
+
|
831 |
+
Args:
|
832 |
+
unique_id: The unique identifier corresponding to that example.
|
833 |
+
start_logits: The logits corresponding to the start of the answer
|
834 |
+
end_logits: The logits corresponding to the end of the answer
|
835 |
+
"""
|
836 |
+
|
837 |
+
def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
|
838 |
+
self.start_logits = start_logits
|
839 |
+
self.end_logits = end_logits
|
840 |
+
self.unique_id = unique_id
|
841 |
+
|
842 |
+
if start_top_index:
|
843 |
+
self.start_top_index = start_top_index
|
844 |
+
self.end_top_index = end_top_index
|
845 |
+
self.cls_logits = cls_logits
|
transformers_4_35_0/data/processors/utils.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import csv
|
18 |
+
import dataclasses
|
19 |
+
import json
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import List, Optional, Union
|
22 |
+
|
23 |
+
from ...utils import is_tf_available, is_torch_available, logging
|
24 |
+
|
25 |
+
|
26 |
+
logger = logging.get_logger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass
|
30 |
+
class InputExample:
|
31 |
+
"""
|
32 |
+
A single training/test example for simple sequence classification.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
guid: Unique id for the example.
|
36 |
+
text_a: string. The untokenized text of the first sequence. For single
|
37 |
+
sequence tasks, only this sequence must be specified.
|
38 |
+
text_b: (Optional) string. The untokenized text of the second sequence.
|
39 |
+
Only must be specified for sequence pair tasks.
|
40 |
+
label: (Optional) string. The label of the example. This should be
|
41 |
+
specified for train and dev examples, but not for test examples.
|
42 |
+
"""
|
43 |
+
|
44 |
+
guid: str
|
45 |
+
text_a: str
|
46 |
+
text_b: Optional[str] = None
|
47 |
+
label: Optional[str] = None
|
48 |
+
|
49 |
+
def to_json_string(self):
|
50 |
+
"""Serializes this instance to a JSON string."""
|
51 |
+
return json.dumps(dataclasses.asdict(self), indent=2) + "\n"
|
52 |
+
|
53 |
+
|
54 |
+
@dataclass(frozen=True)
|
55 |
+
class InputFeatures:
|
56 |
+
"""
|
57 |
+
A single set of features of data. Property names are the same names as the corresponding inputs to a model.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
input_ids: Indices of input sequence tokens in the vocabulary.
|
61 |
+
attention_mask: Mask to avoid performing attention on padding token indices.
|
62 |
+
Mask values selected in `[0, 1]`: Usually `1` for tokens that are NOT MASKED, `0` for MASKED (padded)
|
63 |
+
tokens.
|
64 |
+
token_type_ids: (Optional) Segment token indices to indicate first and second
|
65 |
+
portions of the inputs. Only some models use them.
|
66 |
+
label: (Optional) Label corresponding to the input. Int for classification problems,
|
67 |
+
float for regression problems.
|
68 |
+
"""
|
69 |
+
|
70 |
+
input_ids: List[int]
|
71 |
+
attention_mask: Optional[List[int]] = None
|
72 |
+
token_type_ids: Optional[List[int]] = None
|
73 |
+
label: Optional[Union[int, float]] = None
|
74 |
+
|
75 |
+
def to_json_string(self):
|
76 |
+
"""Serializes this instance to a JSON string."""
|
77 |
+
return json.dumps(dataclasses.asdict(self)) + "\n"
|
78 |
+
|
79 |
+
|
80 |
+
class DataProcessor:
|
81 |
+
"""Base class for data converters for sequence classification data sets."""
|
82 |
+
|
83 |
+
def get_example_from_tensor_dict(self, tensor_dict):
|
84 |
+
"""
|
85 |
+
Gets an example from a dict with tensorflow tensors.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
tensor_dict: Keys and values should match the corresponding Glue
|
89 |
+
tensorflow_dataset examples.
|
90 |
+
"""
|
91 |
+
raise NotImplementedError()
|
92 |
+
|
93 |
+
def get_train_examples(self, data_dir):
|
94 |
+
"""Gets a collection of [`InputExample`] for the train set."""
|
95 |
+
raise NotImplementedError()
|
96 |
+
|
97 |
+
def get_dev_examples(self, data_dir):
|
98 |
+
"""Gets a collection of [`InputExample`] for the dev set."""
|
99 |
+
raise NotImplementedError()
|
100 |
+
|
101 |
+
def get_test_examples(self, data_dir):
|
102 |
+
"""Gets a collection of [`InputExample`] for the test set."""
|
103 |
+
raise NotImplementedError()
|
104 |
+
|
105 |
+
def get_labels(self):
|
106 |
+
"""Gets the list of labels for this data set."""
|
107 |
+
raise NotImplementedError()
|
108 |
+
|
109 |
+
def tfds_map(self, example):
|
110 |
+
"""
|
111 |
+
Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are. This method converts
|
112 |
+
examples to the correct format.
|
113 |
+
"""
|
114 |
+
if len(self.get_labels()) > 1:
|
115 |
+
example.label = self.get_labels()[int(example.label)]
|
116 |
+
return example
|
117 |
+
|
118 |
+
@classmethod
|
119 |
+
def _read_tsv(cls, input_file, quotechar=None):
|
120 |
+
"""Reads a tab separated value file."""
|
121 |
+
with open(input_file, "r", encoding="utf-8-sig") as f:
|
122 |
+
return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
|
123 |
+
|
124 |
+
|
125 |
+
class SingleSentenceClassificationProcessor(DataProcessor):
|
126 |
+
"""Generic processor for a single sentence classification data set."""
|
127 |
+
|
128 |
+
def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
|
129 |
+
self.labels = [] if labels is None else labels
|
130 |
+
self.examples = [] if examples is None else examples
|
131 |
+
self.mode = mode
|
132 |
+
self.verbose = verbose
|
133 |
+
|
134 |
+
def __len__(self):
|
135 |
+
return len(self.examples)
|
136 |
+
|
137 |
+
def __getitem__(self, idx):
|
138 |
+
if isinstance(idx, slice):
|
139 |
+
return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
|
140 |
+
return self.examples[idx]
|
141 |
+
|
142 |
+
@classmethod
|
143 |
+
def create_from_csv(
|
144 |
+
cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
|
145 |
+
):
|
146 |
+
processor = cls(**kwargs)
|
147 |
+
processor.add_examples_from_csv(
|
148 |
+
file_name,
|
149 |
+
split_name=split_name,
|
150 |
+
column_label=column_label,
|
151 |
+
column_text=column_text,
|
152 |
+
column_id=column_id,
|
153 |
+
skip_first_row=skip_first_row,
|
154 |
+
overwrite_labels=True,
|
155 |
+
overwrite_examples=True,
|
156 |
+
)
|
157 |
+
return processor
|
158 |
+
|
159 |
+
@classmethod
|
160 |
+
def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
|
161 |
+
processor = cls(**kwargs)
|
162 |
+
processor.add_examples(texts_or_text_and_labels, labels=labels)
|
163 |
+
return processor
|
164 |
+
|
165 |
+
def add_examples_from_csv(
|
166 |
+
self,
|
167 |
+
file_name,
|
168 |
+
split_name="",
|
169 |
+
column_label=0,
|
170 |
+
column_text=1,
|
171 |
+
column_id=None,
|
172 |
+
skip_first_row=False,
|
173 |
+
overwrite_labels=False,
|
174 |
+
overwrite_examples=False,
|
175 |
+
):
|
176 |
+
lines = self._read_tsv(file_name)
|
177 |
+
if skip_first_row:
|
178 |
+
lines = lines[1:]
|
179 |
+
texts = []
|
180 |
+
labels = []
|
181 |
+
ids = []
|
182 |
+
for i, line in enumerate(lines):
|
183 |
+
texts.append(line[column_text])
|
184 |
+
labels.append(line[column_label])
|
185 |
+
if column_id is not None:
|
186 |
+
ids.append(line[column_id])
|
187 |
+
else:
|
188 |
+
guid = f"{split_name}-{i}" if split_name else str(i)
|
189 |
+
ids.append(guid)
|
190 |
+
|
191 |
+
return self.add_examples(
|
192 |
+
texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
|
193 |
+
)
|
194 |
+
|
195 |
+
def add_examples(
|
196 |
+
self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
|
197 |
+
):
|
198 |
+
if labels is not None and len(texts_or_text_and_labels) != len(labels):
|
199 |
+
raise ValueError(
|
200 |
+
f"Text and labels have mismatched lengths {len(texts_or_text_and_labels)} and {len(labels)}"
|
201 |
+
)
|
202 |
+
if ids is not None and len(texts_or_text_and_labels) != len(ids):
|
203 |
+
raise ValueError(f"Text and ids have mismatched lengths {len(texts_or_text_and_labels)} and {len(ids)}")
|
204 |
+
if ids is None:
|
205 |
+
ids = [None] * len(texts_or_text_and_labels)
|
206 |
+
if labels is None:
|
207 |
+
labels = [None] * len(texts_or_text_and_labels)
|
208 |
+
examples = []
|
209 |
+
added_labels = set()
|
210 |
+
for text_or_text_and_label, label, guid in zip(texts_or_text_and_labels, labels, ids):
|
211 |
+
if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
|
212 |
+
text, label = text_or_text_and_label
|
213 |
+
else:
|
214 |
+
text = text_or_text_and_label
|
215 |
+
added_labels.add(label)
|
216 |
+
examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
|
217 |
+
|
218 |
+
# Update examples
|
219 |
+
if overwrite_examples:
|
220 |
+
self.examples = examples
|
221 |
+
else:
|
222 |
+
self.examples.extend(examples)
|
223 |
+
|
224 |
+
# Update labels
|
225 |
+
if overwrite_labels:
|
226 |
+
self.labels = list(added_labels)
|
227 |
+
else:
|
228 |
+
self.labels = list(set(self.labels).union(added_labels))
|
229 |
+
|
230 |
+
return self.examples
|
231 |
+
|
232 |
+
def get_features(
|
233 |
+
self,
|
234 |
+
tokenizer,
|
235 |
+
max_length=None,
|
236 |
+
pad_on_left=False,
|
237 |
+
pad_token=0,
|
238 |
+
mask_padding_with_zero=True,
|
239 |
+
return_tensors=None,
|
240 |
+
):
|
241 |
+
"""
|
242 |
+
Convert examples in a list of `InputFeatures`
|
243 |
+
|
244 |
+
Args:
|
245 |
+
tokenizer: Instance of a tokenizer that will tokenize the examples
|
246 |
+
max_length: Maximum example length
|
247 |
+
pad_on_left: If set to `True`, the examples will be padded on the left rather than on the right (default)
|
248 |
+
pad_token: Padding token
|
249 |
+
mask_padding_with_zero: If set to `True`, the attention mask will be filled by `1` for actual values
|
250 |
+
and by `0` for padded values. If set to `False`, inverts it (`1` for padded values, `0` for actual
|
251 |
+
values)
|
252 |
+
|
253 |
+
Returns:
|
254 |
+
If the `examples` input is a `tf.data.Dataset`, will return a `tf.data.Dataset` containing the
|
255 |
+
task-specific features. If the input is a list of `InputExamples`, will return a list of task-specific
|
256 |
+
`InputFeatures` which can be fed to the model.
|
257 |
+
|
258 |
+
"""
|
259 |
+
if max_length is None:
|
260 |
+
max_length = tokenizer.max_len
|
261 |
+
|
262 |
+
label_map = {label: i for i, label in enumerate(self.labels)}
|
263 |
+
|
264 |
+
all_input_ids = []
|
265 |
+
for ex_index, example in enumerate(self.examples):
|
266 |
+
if ex_index % 10000 == 0:
|
267 |
+
logger.info(f"Tokenizing example {ex_index}")
|
268 |
+
|
269 |
+
input_ids = tokenizer.encode(
|
270 |
+
example.text_a,
|
271 |
+
add_special_tokens=True,
|
272 |
+
max_length=min(max_length, tokenizer.max_len),
|
273 |
+
)
|
274 |
+
all_input_ids.append(input_ids)
|
275 |
+
|
276 |
+
batch_length = max(len(input_ids) for input_ids in all_input_ids)
|
277 |
+
|
278 |
+
features = []
|
279 |
+
for ex_index, (input_ids, example) in enumerate(zip(all_input_ids, self.examples)):
|
280 |
+
if ex_index % 10000 == 0:
|
281 |
+
logger.info(f"Writing example {ex_index}/{len(self.examples)}")
|
282 |
+
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
283 |
+
# tokens are attended to.
|
284 |
+
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
285 |
+
|
286 |
+
# Zero-pad up to the sequence length.
|
287 |
+
padding_length = batch_length - len(input_ids)
|
288 |
+
if pad_on_left:
|
289 |
+
input_ids = ([pad_token] * padding_length) + input_ids
|
290 |
+
attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
|
291 |
+
else:
|
292 |
+
input_ids = input_ids + ([pad_token] * padding_length)
|
293 |
+
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
|
294 |
+
|
295 |
+
if len(input_ids) != batch_length:
|
296 |
+
raise ValueError(f"Error with input length {len(input_ids)} vs {batch_length}")
|
297 |
+
if len(attention_mask) != batch_length:
|
298 |
+
raise ValueError(f"Error with input length {len(attention_mask)} vs {batch_length}")
|
299 |
+
|
300 |
+
if self.mode == "classification":
|
301 |
+
label = label_map[example.label]
|
302 |
+
elif self.mode == "regression":
|
303 |
+
label = float(example.label)
|
304 |
+
else:
|
305 |
+
raise ValueError(self.mode)
|
306 |
+
|
307 |
+
if ex_index < 5 and self.verbose:
|
308 |
+
logger.info("*** Example ***")
|
309 |
+
logger.info(f"guid: {example.guid}")
|
310 |
+
logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
|
311 |
+
logger.info(f"attention_mask: {' '.join([str(x) for x in attention_mask])}")
|
312 |
+
logger.info(f"label: {example.label} (id = {label})")
|
313 |
+
|
314 |
+
features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, label=label))
|
315 |
+
|
316 |
+
if return_tensors is None:
|
317 |
+
return features
|
318 |
+
elif return_tensors == "tf":
|
319 |
+
if not is_tf_available():
|
320 |
+
raise RuntimeError("return_tensors set to 'tf' but TensorFlow 2.0 can't be imported")
|
321 |
+
import tensorflow as tf
|
322 |
+
|
323 |
+
def gen():
|
324 |
+
for ex in features:
|
325 |
+
yield ({"input_ids": ex.input_ids, "attention_mask": ex.attention_mask}, ex.label)
|
326 |
+
|
327 |
+
dataset = tf.data.Dataset.from_generator(
|
328 |
+
gen,
|
329 |
+
({"input_ids": tf.int32, "attention_mask": tf.int32}, tf.int64),
|
330 |
+
({"input_ids": tf.TensorShape([None]), "attention_mask": tf.TensorShape([None])}, tf.TensorShape([])),
|
331 |
+
)
|
332 |
+
return dataset
|
333 |
+
elif return_tensors == "pt":
|
334 |
+
if not is_torch_available():
|
335 |
+
raise RuntimeError("return_tensors set to 'pt' but PyTorch can't be imported")
|
336 |
+
import torch
|
337 |
+
from torch.utils.data import TensorDataset
|
338 |
+
|
339 |
+
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
340 |
+
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
341 |
+
if self.mode == "classification":
|
342 |
+
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
343 |
+
elif self.mode == "regression":
|
344 |
+
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
345 |
+
|
346 |
+
dataset = TensorDataset(all_input_ids, all_attention_mask, all_labels)
|
347 |
+
return dataset
|
348 |
+
else:
|
349 |
+
raise ValueError("return_tensors should be one of 'tf' or 'pt'")
|
transformers_4_35_0/data/processors/xnli.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" XNLI utils (dataset loading and evaluation)"""
|
17 |
+
|
18 |
+
|
19 |
+
import os
|
20 |
+
|
21 |
+
from ...utils import logging
|
22 |
+
from .utils import DataProcessor, InputExample
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
class XnliProcessor(DataProcessor):
|
29 |
+
"""
|
30 |
+
Processor for the XNLI dataset. Adapted from
|
31 |
+
https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/run_classifier.py#L207
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, language, train_language=None):
|
35 |
+
self.language = language
|
36 |
+
self.train_language = train_language
|
37 |
+
|
38 |
+
def get_train_examples(self, data_dir):
|
39 |
+
"""See base class."""
|
40 |
+
lg = self.language if self.train_language is None else self.train_language
|
41 |
+
lines = self._read_tsv(os.path.join(data_dir, f"XNLI-MT-1.0/multinli/multinli.train.{lg}.tsv"))
|
42 |
+
examples = []
|
43 |
+
for i, line in enumerate(lines):
|
44 |
+
if i == 0:
|
45 |
+
continue
|
46 |
+
guid = f"train-{i}"
|
47 |
+
text_a = line[0]
|
48 |
+
text_b = line[1]
|
49 |
+
label = "contradiction" if line[2] == "contradictory" else line[2]
|
50 |
+
if not isinstance(text_a, str):
|
51 |
+
raise ValueError(f"Training input {text_a} is not a string")
|
52 |
+
if not isinstance(text_b, str):
|
53 |
+
raise ValueError(f"Training input {text_b} is not a string")
|
54 |
+
if not isinstance(label, str):
|
55 |
+
raise ValueError(f"Training label {label} is not a string")
|
56 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
57 |
+
return examples
|
58 |
+
|
59 |
+
def get_test_examples(self, data_dir):
|
60 |
+
"""See base class."""
|
61 |
+
lines = self._read_tsv(os.path.join(data_dir, "XNLI-1.0/xnli.test.tsv"))
|
62 |
+
examples = []
|
63 |
+
for i, line in enumerate(lines):
|
64 |
+
if i == 0:
|
65 |
+
continue
|
66 |
+
language = line[0]
|
67 |
+
if language != self.language:
|
68 |
+
continue
|
69 |
+
guid = f"test-{i}"
|
70 |
+
text_a = line[6]
|
71 |
+
text_b = line[7]
|
72 |
+
label = line[1]
|
73 |
+
if not isinstance(text_a, str):
|
74 |
+
raise ValueError(f"Training input {text_a} is not a string")
|
75 |
+
if not isinstance(text_b, str):
|
76 |
+
raise ValueError(f"Training input {text_b} is not a string")
|
77 |
+
if not isinstance(label, str):
|
78 |
+
raise ValueError(f"Training label {label} is not a string")
|
79 |
+
examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
80 |
+
return examples
|
81 |
+
|
82 |
+
def get_labels(self):
|
83 |
+
"""See base class."""
|
84 |
+
return ["contradiction", "entailment", "neutral"]
|
85 |
+
|
86 |
+
|
87 |
+
xnli_processors = {
|
88 |
+
"xnli": XnliProcessor,
|
89 |
+
}
|
90 |
+
|
91 |
+
xnli_output_modes = {
|
92 |
+
"xnli": "classification",
|
93 |
+
}
|
94 |
+
|
95 |
+
xnli_tasks_num_labels = {
|
96 |
+
"xnli": 3,
|
97 |
+
}
|
transformers_4_35_0/debug_utils.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import collections
|
16 |
+
|
17 |
+
from .utils import ExplicitEnum, is_torch_available, logging
|
18 |
+
|
19 |
+
|
20 |
+
if is_torch_available():
|
21 |
+
import torch
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
class DebugUnderflowOverflow:
|
28 |
+
"""
|
29 |
+
This debug class helps detect and understand where the model starts getting very large or very small, and more
|
30 |
+
importantly `nan` or `inf` weight and activation elements.
|
31 |
+
|
32 |
+
There are 2 working modes:
|
33 |
+
|
34 |
+
1. Underflow/overflow detection (default)
|
35 |
+
2. Specific batch absolute min/max tracing without detection
|
36 |
+
|
37 |
+
Mode 1: Underflow/overflow detection
|
38 |
+
|
39 |
+
To activate the underflow/overflow detection, initialize the object with the model :
|
40 |
+
|
41 |
+
```python
|
42 |
+
debug_overflow = DebugUnderflowOverflow(model)
|
43 |
+
```
|
44 |
+
|
45 |
+
then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or output
|
46 |
+
elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this event,
|
47 |
+
each frame reporting
|
48 |
+
|
49 |
+
1. the fully qualified module name plus the class name whose `forward` was run
|
50 |
+
2. the absolute min and max value of all elements for each module weights, and the inputs and output
|
51 |
+
|
52 |
+
For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16
|
53 |
+
mixed precision :
|
54 |
+
|
55 |
+
```
|
56 |
+
Detected inf/nan during batch_number=0
|
57 |
+
Last 21 forward frames:
|
58 |
+
abs min abs max metadata
|
59 |
+
[...]
|
60 |
+
encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
|
61 |
+
2.17e-07 4.50e+00 weight
|
62 |
+
1.79e-06 4.65e+00 input[0]
|
63 |
+
2.68e-06 3.70e+01 output
|
64 |
+
encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
|
65 |
+
8.08e-07 2.66e+01 weight
|
66 |
+
1.79e-06 4.65e+00 input[0]
|
67 |
+
1.27e-04 2.37e+02 output
|
68 |
+
encoder.block.2.layer.1.DenseReluDense.wo Linear
|
69 |
+
1.01e-06 6.44e+00 weight
|
70 |
+
0.00e+00 9.74e+03 input[0]
|
71 |
+
3.18e-04 6.27e+04 output
|
72 |
+
encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
|
73 |
+
1.79e-06 4.65e+00 input[0]
|
74 |
+
3.18e-04 6.27e+04 output
|
75 |
+
encoder.block.2.layer.1.dropout Dropout
|
76 |
+
3.18e-04 6.27e+04 input[0]
|
77 |
+
0.00e+00 inf output
|
78 |
+
```
|
79 |
+
|
80 |
+
You can see here, that `T5DenseGatedGeluDense.forward` resulted in output activations, whose absolute max value was
|
81 |
+
around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout` which
|
82 |
+
renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than
|
83 |
+
64K, and we get an overlow.
|
84 |
+
|
85 |
+
As you can see it's the previous frames that we need to look into when the numbers start going into very large for
|
86 |
+
fp16 numbers.
|
87 |
+
|
88 |
+
The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed.
|
89 |
+
|
90 |
+
By default the last 21 frames are printed. You can change the default to adjust for your needs. For example :
|
91 |
+
|
92 |
+
```python
|
93 |
+
debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
|
94 |
+
```
|
95 |
+
|
96 |
+
To validate that you have set up this debugging feature correctly, and you intend to use it in a training that
|
97 |
+
may take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in
|
98 |
+
the next section.
|
99 |
+
|
100 |
+
|
101 |
+
Mode 2. Specific batch absolute min/max tracing without detection
|
102 |
+
|
103 |
+
The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.
|
104 |
+
|
105 |
+
Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a
|
106 |
+
given batch, and only do that for batches 1 and 3. Then you instantiate this class as :
|
107 |
+
|
108 |
+
```python
|
109 |
+
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])
|
110 |
+
```
|
111 |
+
|
112 |
+
And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.
|
113 |
+
|
114 |
+
This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
|
115 |
+
fast-forward right to that area.
|
116 |
+
|
117 |
+
|
118 |
+
Early stopping:
|
119 |
+
|
120 |
+
You can also specify the batch number after which to stop the training, with :
|
121 |
+
|
122 |
+
```python
|
123 |
+
debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
|
124 |
+
```
|
125 |
+
|
126 |
+
This feature is mainly useful in the tracing mode, but you can use it for any mode.
|
127 |
+
|
128 |
+
|
129 |
+
**Performance**:
|
130 |
+
|
131 |
+
As this module measures absolute `min`/``max` of each weight of the model on every forward it'll slow the training
|
132 |
+
down. Therefore remember to turn it off once the debugging needs have been met.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
model (`nn.Module`):
|
136 |
+
The model to debug.
|
137 |
+
max_frames_to_save (`int`, *optional*, defaults to 21):
|
138 |
+
How many frames back to record
|
139 |
+
trace_batch_nums(`List[int]`, *optional*, defaults to `[]`):
|
140 |
+
Which batch numbers to trace (turns detection off)
|
141 |
+
abort_after_batch_num (`int``, *optional*):
|
142 |
+
Whether to abort after a certain batch number has finished
|
143 |
+
"""
|
144 |
+
|
145 |
+
def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None):
|
146 |
+
self.model = model
|
147 |
+
self.trace_batch_nums = trace_batch_nums
|
148 |
+
self.abort_after_batch_num = abort_after_batch_num
|
149 |
+
|
150 |
+
# keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
|
151 |
+
self.frames = collections.deque([], max_frames_to_save)
|
152 |
+
self.frame = []
|
153 |
+
self.batch_number = 0
|
154 |
+
self.total_calls = 0
|
155 |
+
self.detected_overflow = False
|
156 |
+
self.prefix = " "
|
157 |
+
|
158 |
+
self.analyse_model()
|
159 |
+
|
160 |
+
self.register_forward_hook()
|
161 |
+
|
162 |
+
def save_frame(self, frame=None):
|
163 |
+
if frame is not None:
|
164 |
+
self.expand_frame(frame)
|
165 |
+
self.frames.append("\n".join(self.frame))
|
166 |
+
self.frame = [] # start a new frame
|
167 |
+
|
168 |
+
def expand_frame(self, line):
|
169 |
+
self.frame.append(line)
|
170 |
+
|
171 |
+
def trace_frames(self):
|
172 |
+
print("\n".join(self.frames))
|
173 |
+
self.frames = []
|
174 |
+
|
175 |
+
def reset_saved_frames(self):
|
176 |
+
self.frames = []
|
177 |
+
|
178 |
+
def dump_saved_frames(self):
|
179 |
+
print(f"\nDetected inf/nan during batch_number={self.batch_number}")
|
180 |
+
print(f"Last {len(self.frames)} forward frames:")
|
181 |
+
print(f"{'abs min':8} {'abs max':8} metadata")
|
182 |
+
print("\n".join(self.frames))
|
183 |
+
print("\n\n")
|
184 |
+
self.frames = []
|
185 |
+
|
186 |
+
def analyse_model(self):
|
187 |
+
# extract the fully qualified module names, to be able to report at run time. e.g.:
|
188 |
+
# encoder.block.2.layer.0.SelfAttention.o
|
189 |
+
#
|
190 |
+
# for shared weights only the first shared module name will be registered
|
191 |
+
self.module_names = {m: name for name, m in self.model.named_modules()}
|
192 |
+
# self.longest_module_name = max(len(v) for v in self.module_names.values())
|
193 |
+
|
194 |
+
def analyse_variable(self, var, ctx):
|
195 |
+
if torch.is_tensor(var):
|
196 |
+
self.expand_frame(get_abs_min_max(var, ctx))
|
197 |
+
if detect_overflow(var, ctx):
|
198 |
+
self.detected_overflow = True
|
199 |
+
elif var is None:
|
200 |
+
self.expand_frame(f"{'None':>17} {ctx}")
|
201 |
+
else:
|
202 |
+
self.expand_frame(f"{'not a tensor':>17} {ctx}")
|
203 |
+
|
204 |
+
def batch_start_frame(self):
|
205 |
+
self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
|
206 |
+
self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")
|
207 |
+
|
208 |
+
def batch_end_frame(self):
|
209 |
+
self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n")
|
210 |
+
|
211 |
+
def create_frame(self, module, input, output):
|
212 |
+
self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")
|
213 |
+
|
214 |
+
# params
|
215 |
+
for name, p in module.named_parameters(recurse=False):
|
216 |
+
self.analyse_variable(p, name)
|
217 |
+
|
218 |
+
# inputs
|
219 |
+
if isinstance(input, tuple):
|
220 |
+
for i, x in enumerate(input):
|
221 |
+
self.analyse_variable(x, f"input[{i}]")
|
222 |
+
else:
|
223 |
+
self.analyse_variable(input, "input")
|
224 |
+
|
225 |
+
# outputs
|
226 |
+
if isinstance(output, tuple):
|
227 |
+
for i, x in enumerate(output):
|
228 |
+
# possibly a tuple of tuples
|
229 |
+
if isinstance(x, tuple):
|
230 |
+
for j, y in enumerate(x):
|
231 |
+
self.analyse_variable(y, f"output[{i}][{j}]")
|
232 |
+
else:
|
233 |
+
self.analyse_variable(x, f"output[{i}]")
|
234 |
+
else:
|
235 |
+
self.analyse_variable(output, "output")
|
236 |
+
|
237 |
+
self.save_frame()
|
238 |
+
|
239 |
+
def register_forward_hook(self):
|
240 |
+
self.model.apply(self._register_forward_hook)
|
241 |
+
|
242 |
+
def _register_forward_hook(self, module):
|
243 |
+
module.register_forward_hook(self.forward_hook)
|
244 |
+
|
245 |
+
def forward_hook(self, module, input, output):
|
246 |
+
# - input is a tuple of packed inputs (could be non-Tensors)
|
247 |
+
# - output could be a Tensor or a tuple of Tensors and non-Tensors
|
248 |
+
|
249 |
+
last_frame_of_batch = False
|
250 |
+
|
251 |
+
trace_mode = True if self.batch_number in self.trace_batch_nums else False
|
252 |
+
if trace_mode:
|
253 |
+
self.reset_saved_frames()
|
254 |
+
|
255 |
+
if self.total_calls == 0:
|
256 |
+
self.batch_start_frame()
|
257 |
+
self.total_calls += 1
|
258 |
+
|
259 |
+
# count batch numbers - the very first forward hook of the batch will be called when the
|
260 |
+
# batch completes - i.e. it gets called very last - we know this batch has finished
|
261 |
+
if module == self.model:
|
262 |
+
self.batch_number += 1
|
263 |
+
last_frame_of_batch = True
|
264 |
+
|
265 |
+
self.create_frame(module, input, output)
|
266 |
+
|
267 |
+
# if last_frame_of_batch:
|
268 |
+
# self.batch_end_frame()
|
269 |
+
|
270 |
+
if trace_mode:
|
271 |
+
self.trace_frames()
|
272 |
+
|
273 |
+
if last_frame_of_batch:
|
274 |
+
self.batch_start_frame()
|
275 |
+
|
276 |
+
if self.detected_overflow and not trace_mode:
|
277 |
+
self.dump_saved_frames()
|
278 |
+
|
279 |
+
# now we can abort, as it's pointless to continue running
|
280 |
+
raise ValueError(
|
281 |
+
"DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
|
282 |
+
"Please scroll up above this traceback to see the activation values prior to this event."
|
283 |
+
)
|
284 |
+
|
285 |
+
# abort after certain batch if requested to do so
|
286 |
+
if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
|
287 |
+
raise ValueError(
|
288 |
+
f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
|
289 |
+
f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
|
290 |
+
)
|
291 |
+
|
292 |
+
|
293 |
+
def get_abs_min_max(var, ctx):
|
294 |
+
abs_var = var.abs()
|
295 |
+
return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}"
|
296 |
+
|
297 |
+
|
298 |
+
def detect_overflow(var, ctx):
|
299 |
+
"""
|
300 |
+
Report whether the tensor contains any `nan` or `inf` entries.
|
301 |
+
|
302 |
+
This is useful for detecting overflows/underflows and best to call right after the function that did some math that
|
303 |
+
modified the tensor in question.
|
304 |
+
|
305 |
+
This function contains a few other helper features that you can enable and tweak directly if you want to track
|
306 |
+
various other things.
|
307 |
+
|
308 |
+
Args:
|
309 |
+
var: the tensor variable to check
|
310 |
+
ctx: the message to print as a context
|
311 |
+
|
312 |
+
Return:
|
313 |
+
`True` if `inf` or `nan` was detected, `False` otherwise
|
314 |
+
"""
|
315 |
+
detected = False
|
316 |
+
if torch.isnan(var).any().item():
|
317 |
+
detected = True
|
318 |
+
print(f"{ctx} has nans")
|
319 |
+
if torch.isinf(var).any().item():
|
320 |
+
detected = True
|
321 |
+
print(f"{ctx} has infs")
|
322 |
+
|
323 |
+
# if needed to monitor large elements can enable the following
|
324 |
+
if 0: # and detected:
|
325 |
+
n100 = var[torch.ge(var.abs(), 100)]
|
326 |
+
if n100.numel() > 0:
|
327 |
+
print(f"{ctx}: n100={n100.numel()}")
|
328 |
+
n1000 = var[torch.ge(var.abs(), 1000)]
|
329 |
+
if n1000.numel() > 0:
|
330 |
+
print(f"{ctx}: n1000={n1000.numel()}")
|
331 |
+
n10000 = var[torch.ge(var.abs(), 10000)]
|
332 |
+
if n10000.numel() > 0:
|
333 |
+
print(f"{ctx}: n10000={n10000.numel()}")
|
334 |
+
|
335 |
+
if 0:
|
336 |
+
print(f"min={var.min():9.2e} max={var.max():9.2e}")
|
337 |
+
|
338 |
+
if 0:
|
339 |
+
print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")
|
340 |
+
|
341 |
+
return detected
|
342 |
+
|
343 |
+
|
344 |
+
class DebugOption(ExplicitEnum):
|
345 |
+
UNDERFLOW_OVERFLOW = "underflow_overflow"
|
346 |
+
TPU_METRICS_DEBUG = "tpu_metrics_debug"
|
transformers_4_35_0/deepspeed.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Integration with Deepspeed - kept for backward compatiblity, if you plan to make any edit, make sure to modify the file
|
16 |
+
in `integrations/deepspeed` instead.
|
17 |
+
|
18 |
+
Check: https://github.com/huggingface/transformers/pull/25599
|
19 |
+
"""
|
20 |
+
import warnings
|
21 |
+
|
22 |
+
|
23 |
+
warnings.warn(
|
24 |
+
"transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations",
|
25 |
+
FutureWarning,
|
26 |
+
)
|
27 |
+
|
28 |
+
# Backward compatibility imports, to make sure all those objects can be found in integrations/deepspeed
|
29 |
+
from .integrations.deepspeed import ( # noqa
|
30 |
+
HfDeepSpeedConfig,
|
31 |
+
HfTrainerDeepSpeedConfig,
|
32 |
+
deepspeed_config,
|
33 |
+
deepspeed_init,
|
34 |
+
deepspeed_load_checkpoint,
|
35 |
+
deepspeed_optim_sched,
|
36 |
+
is_deepspeed_available,
|
37 |
+
is_deepspeed_zero3_enabled,
|
38 |
+
set_hf_deepspeed_config,
|
39 |
+
unset_hf_deepspeed_config,
|
40 |
+
)
|
transformers_4_35_0/dependency_versions_check.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .dependency_versions_table import deps
|
16 |
+
from .utils.versions import require_version, require_version_core
|
17 |
+
|
18 |
+
|
19 |
+
# define which module versions we always want to check at run time
|
20 |
+
# (usually the ones defined in `install_requires` in setup.py)
|
21 |
+
#
|
22 |
+
# order specific notes:
|
23 |
+
# - tqdm must be checked before tokenizers
|
24 |
+
|
25 |
+
pkgs_to_check_at_runtime = [
|
26 |
+
"python",
|
27 |
+
"tqdm",
|
28 |
+
"regex",
|
29 |
+
"requests",
|
30 |
+
"packaging",
|
31 |
+
"filelock",
|
32 |
+
"numpy",
|
33 |
+
"tokenizers",
|
34 |
+
"huggingface-hub",
|
35 |
+
"safetensors",
|
36 |
+
"accelerate",
|
37 |
+
"pyyaml",
|
38 |
+
]
|
39 |
+
|
40 |
+
for pkg in pkgs_to_check_at_runtime:
|
41 |
+
if pkg in deps:
|
42 |
+
if pkg == "tokenizers":
|
43 |
+
# must be loaded here, or else tqdm check may fail
|
44 |
+
from .utils import is_tokenizers_available
|
45 |
+
|
46 |
+
if not is_tokenizers_available():
|
47 |
+
continue # not required, check version only if installed
|
48 |
+
elif pkg == "accelerate":
|
49 |
+
# must be loaded here, or else tqdm check may fail
|
50 |
+
from .utils import is_accelerate_available
|
51 |
+
|
52 |
+
# Maybe switch to is_torch_available in the future here so that Accelerate is hard dep of
|
53 |
+
# Transformers with PyTorch
|
54 |
+
if not is_accelerate_available():
|
55 |
+
continue # not required, check version only if installed
|
56 |
+
|
57 |
+
require_version_core(deps[pkg])
|
58 |
+
else:
|
59 |
+
raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
|
60 |
+
|
61 |
+
|
62 |
+
def dep_version_check(pkg, hint=None):
|
63 |
+
require_version(deps[pkg], hint)
|
transformers_4_35_0/dependency_versions_table.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# THIS FILE HAS BEEN AUTOGENERATED. To update:
|
2 |
+
# 1. modify the `_deps` dict in setup.py
|
3 |
+
# 2. run `make deps_table_update``
|
4 |
+
deps = {
|
5 |
+
"Pillow": "Pillow<10.0.0",
|
6 |
+
"accelerate": "accelerate>=0.20.3",
|
7 |
+
"av": "av==9.2.0",
|
8 |
+
"beautifulsoup4": "beautifulsoup4",
|
9 |
+
"black": "black~=23.1",
|
10 |
+
"codecarbon": "codecarbon==1.2.0",
|
11 |
+
"cookiecutter": "cookiecutter==1.7.3",
|
12 |
+
"dataclasses": "dataclasses",
|
13 |
+
"datasets": "datasets!=2.5.0",
|
14 |
+
"decord": "decord==0.6.0",
|
15 |
+
"deepspeed": "deepspeed>=0.9.3",
|
16 |
+
"diffusers": "diffusers",
|
17 |
+
"dill": "dill<0.3.5",
|
18 |
+
"evaluate": "evaluate>=0.2.0",
|
19 |
+
"faiss-cpu": "faiss-cpu",
|
20 |
+
"fastapi": "fastapi",
|
21 |
+
"filelock": "filelock",
|
22 |
+
"flax": "flax>=0.4.1,<=0.7.0",
|
23 |
+
"ftfy": "ftfy",
|
24 |
+
"fugashi": "fugashi>=1.0",
|
25 |
+
"GitPython": "GitPython<3.1.19",
|
26 |
+
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
27 |
+
"huggingface-hub": "huggingface-hub>=0.16.4,<1.0",
|
28 |
+
"importlib_metadata": "importlib_metadata",
|
29 |
+
"ipadic": "ipadic>=1.0.0,<2.0",
|
30 |
+
"isort": "isort>=5.5.4",
|
31 |
+
"jax": "jax>=0.4.1,<=0.4.13",
|
32 |
+
"jaxlib": "jaxlib>=0.4.1,<=0.4.13",
|
33 |
+
"jieba": "jieba",
|
34 |
+
"kenlm": "kenlm",
|
35 |
+
"keras-nlp": "keras-nlp>=0.3.1",
|
36 |
+
"librosa": "librosa",
|
37 |
+
"nltk": "nltk",
|
38 |
+
"natten": "natten>=0.14.6",
|
39 |
+
"numpy": "numpy>=1.17",
|
40 |
+
"onnxconverter-common": "onnxconverter-common",
|
41 |
+
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
|
42 |
+
"onnxruntime": "onnxruntime>=1.4.0",
|
43 |
+
"opencv-python": "opencv-python",
|
44 |
+
"optuna": "optuna",
|
45 |
+
"optax": "optax>=0.0.8,<=0.1.4",
|
46 |
+
"packaging": "packaging>=20.0",
|
47 |
+
"parameterized": "parameterized",
|
48 |
+
"phonemizer": "phonemizer",
|
49 |
+
"protobuf": "protobuf",
|
50 |
+
"psutil": "psutil",
|
51 |
+
"pyyaml": "pyyaml>=5.1",
|
52 |
+
"pydantic": "pydantic<2",
|
53 |
+
"pytest": "pytest>=7.2.0",
|
54 |
+
"pytest-timeout": "pytest-timeout",
|
55 |
+
"pytest-xdist": "pytest-xdist",
|
56 |
+
"python": "python>=3.8.0",
|
57 |
+
"ray[tune]": "ray[tune]",
|
58 |
+
"regex": "regex!=2019.12.17",
|
59 |
+
"requests": "requests",
|
60 |
+
"rhoknp": "rhoknp>=1.1.0,<1.3.1",
|
61 |
+
"rjieba": "rjieba",
|
62 |
+
"rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
|
63 |
+
"ruff": "ruff>=0.0.241,<=0.0.259",
|
64 |
+
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
|
65 |
+
"sacremoses": "sacremoses",
|
66 |
+
"safetensors": "safetensors>=0.3.1",
|
67 |
+
"sagemaker": "sagemaker>=2.31.0",
|
68 |
+
"scikit-learn": "scikit-learn",
|
69 |
+
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
70 |
+
"sigopt": "sigopt",
|
71 |
+
"starlette": "starlette",
|
72 |
+
"sudachipy": "sudachipy>=0.6.6",
|
73 |
+
"sudachidict_core": "sudachidict_core>=20220729",
|
74 |
+
"tensorflow-cpu": "tensorflow-cpu>=2.6,<2.15",
|
75 |
+
"tensorflow": "tensorflow>=2.6,<2.15",
|
76 |
+
"tensorflow-text": "tensorflow-text<2.15",
|
77 |
+
"tf2onnx": "tf2onnx",
|
78 |
+
"timeout-decorator": "timeout-decorator",
|
79 |
+
"timm": "timm",
|
80 |
+
"tokenizers": "tokenizers>=0.14,<0.15",
|
81 |
+
"torch": "torch>=1.10,!=1.12.0",
|
82 |
+
"torchaudio": "torchaudio",
|
83 |
+
"torchvision": "torchvision",
|
84 |
+
"pyctcdecode": "pyctcdecode>=0.4.0",
|
85 |
+
"tqdm": "tqdm>=4.27",
|
86 |
+
"unidic": "unidic>=1.0.2",
|
87 |
+
"unidic_lite": "unidic_lite>=1.0.7",
|
88 |
+
"urllib3": "urllib3<2.0.0",
|
89 |
+
"uvicorn": "uvicorn",
|
90 |
+
}
|
transformers_4_35_0/dynamic_module_utils.py
ADDED
@@ -0,0 +1,624 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Utilities to dynamically load objects from the Hub."""
|
16 |
+
import filecmp
|
17 |
+
import importlib
|
18 |
+
import os
|
19 |
+
import re
|
20 |
+
import shutil
|
21 |
+
import signal
|
22 |
+
import sys
|
23 |
+
import typing
|
24 |
+
import warnings
|
25 |
+
from pathlib import Path
|
26 |
+
from typing import Any, Dict, List, Optional, Union
|
27 |
+
|
28 |
+
from .utils import (
|
29 |
+
HF_MODULES_CACHE,
|
30 |
+
TRANSFORMERS_DYNAMIC_MODULE_NAME,
|
31 |
+
cached_file,
|
32 |
+
extract_commit_hash,
|
33 |
+
is_offline_mode,
|
34 |
+
logging,
|
35 |
+
try_to_load_from_cache,
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
40 |
+
|
41 |
+
|
42 |
+
def init_hf_modules():
|
43 |
+
"""
|
44 |
+
Creates the cache directory for modules with an init, and adds it to the Python path.
|
45 |
+
"""
|
46 |
+
# This function has already been executed if HF_MODULES_CACHE already is in the Python path.
|
47 |
+
if HF_MODULES_CACHE in sys.path:
|
48 |
+
return
|
49 |
+
|
50 |
+
sys.path.append(HF_MODULES_CACHE)
|
51 |
+
os.makedirs(HF_MODULES_CACHE, exist_ok=True)
|
52 |
+
init_path = Path(HF_MODULES_CACHE) / "__init__.py"
|
53 |
+
if not init_path.exists():
|
54 |
+
init_path.touch()
|
55 |
+
importlib.invalidate_caches()
|
56 |
+
|
57 |
+
|
58 |
+
def create_dynamic_module(name: Union[str, os.PathLike]):
|
59 |
+
"""
|
60 |
+
Creates a dynamic module in the cache directory for modules.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
name (`str` or `os.PathLike`):
|
64 |
+
The name of the dynamic module to create.
|
65 |
+
"""
|
66 |
+
init_hf_modules()
|
67 |
+
dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()
|
68 |
+
# If the parent module does not exist yet, recursively create it.
|
69 |
+
if not dynamic_module_path.parent.exists():
|
70 |
+
create_dynamic_module(dynamic_module_path.parent)
|
71 |
+
os.makedirs(dynamic_module_path, exist_ok=True)
|
72 |
+
init_path = dynamic_module_path / "__init__.py"
|
73 |
+
if not init_path.exists():
|
74 |
+
init_path.touch()
|
75 |
+
# It is extremely important to invalidate the cache when we change stuff in those modules, or users end up
|
76 |
+
# with errors about module that do not exist. Same for all other `invalidate_caches` in this file.
|
77 |
+
importlib.invalidate_caches()
|
78 |
+
|
79 |
+
|
80 |
+
def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
|
81 |
+
"""
|
82 |
+
Get the list of modules that are relatively imported in a module file.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
module_file (`str` or `os.PathLike`): The module file to inspect.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
`List[str]`: The list of relative imports in the module.
|
89 |
+
"""
|
90 |
+
with open(module_file, "r", encoding="utf-8") as f:
|
91 |
+
content = f.read()
|
92 |
+
|
93 |
+
# Imports of the form `import .xxx`
|
94 |
+
relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
|
95 |
+
# Imports of the form `from .xxx import yyy`
|
96 |
+
relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
|
97 |
+
# Unique-ify
|
98 |
+
return list(set(relative_imports))
|
99 |
+
|
100 |
+
|
101 |
+
def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
|
102 |
+
"""
|
103 |
+
Get the list of all files that are needed for a given module. Note that this function recurses through the relative
|
104 |
+
imports (if a imports b and b imports c, it will return module files for b and c).
|
105 |
+
|
106 |
+
Args:
|
107 |
+
module_file (`str` or `os.PathLike`): The module file to inspect.
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
`List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
|
111 |
+
of module files a given module needs.
|
112 |
+
"""
|
113 |
+
no_change = False
|
114 |
+
files_to_check = [module_file]
|
115 |
+
all_relative_imports = []
|
116 |
+
|
117 |
+
# Let's recurse through all relative imports
|
118 |
+
while not no_change:
|
119 |
+
new_imports = []
|
120 |
+
for f in files_to_check:
|
121 |
+
new_imports.extend(get_relative_imports(f))
|
122 |
+
|
123 |
+
module_path = Path(module_file).parent
|
124 |
+
new_import_files = [str(module_path / m) for m in new_imports]
|
125 |
+
new_import_files = [f for f in new_import_files if f not in all_relative_imports]
|
126 |
+
files_to_check = [f"{f}.py" for f in new_import_files]
|
127 |
+
|
128 |
+
no_change = len(new_import_files) == 0
|
129 |
+
all_relative_imports.extend(files_to_check)
|
130 |
+
|
131 |
+
return all_relative_imports
|
132 |
+
|
133 |
+
|
134 |
+
def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
|
135 |
+
"""
|
136 |
+
Extracts all the libraries (not relative imports this time) that are imported in a file.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
filename (`str` or `os.PathLike`): The module file to inspect.
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
`List[str]`: The list of all packages required to use the input module.
|
143 |
+
"""
|
144 |
+
with open(filename, "r", encoding="utf-8") as f:
|
145 |
+
content = f.read()
|
146 |
+
|
147 |
+
# filter out try/except block so in custom code we can have try/except imports
|
148 |
+
content = re.sub(r"\s*try\s*:\s*.*?\s*except\s*.*?:", "", content, flags=re.MULTILINE | re.DOTALL)
|
149 |
+
|
150 |
+
# Imports of the form `import xxx`
|
151 |
+
imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
|
152 |
+
# Imports of the form `from xxx import yyy`
|
153 |
+
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
|
154 |
+
# Only keep the top-level module
|
155 |
+
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
|
156 |
+
return list(set(imports))
|
157 |
+
|
158 |
+
|
159 |
+
def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
|
160 |
+
"""
|
161 |
+
Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
|
162 |
+
library is missing.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
filename (`str` or `os.PathLike`): The module file to check.
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
`List[str]`: The list of relative imports in the file.
|
169 |
+
"""
|
170 |
+
imports = get_imports(filename)
|
171 |
+
missing_packages = []
|
172 |
+
for imp in imports:
|
173 |
+
try:
|
174 |
+
importlib.import_module(imp)
|
175 |
+
except ImportError:
|
176 |
+
missing_packages.append(imp)
|
177 |
+
|
178 |
+
if len(missing_packages) > 0:
|
179 |
+
raise ImportError(
|
180 |
+
"This modeling file requires the following packages that were not found in your environment: "
|
181 |
+
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
|
182 |
+
)
|
183 |
+
|
184 |
+
return get_relative_imports(filename)
|
185 |
+
|
186 |
+
|
187 |
+
def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -> typing.Type:
|
188 |
+
"""
|
189 |
+
Import a module on the cache directory for modules and extract a class from it.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
class_name (`str`): The name of the class to import.
|
193 |
+
module_path (`str` or `os.PathLike`): The path to the module to import.
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
`typing.Type`: The class looked for.
|
197 |
+
"""
|
198 |
+
module_path = module_path.replace(os.path.sep, ".")
|
199 |
+
module = importlib.import_module(module_path)
|
200 |
+
return getattr(module, class_name)
|
201 |
+
|
202 |
+
|
203 |
+
def get_cached_module_file(
|
204 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
205 |
+
module_file: str,
|
206 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
207 |
+
force_download: bool = False,
|
208 |
+
resume_download: bool = False,
|
209 |
+
proxies: Optional[Dict[str, str]] = None,
|
210 |
+
token: Optional[Union[bool, str]] = None,
|
211 |
+
revision: Optional[str] = None,
|
212 |
+
local_files_only: bool = False,
|
213 |
+
repo_type: Optional[str] = None,
|
214 |
+
_commit_hash: Optional[str] = None,
|
215 |
+
**deprecated_kwargs,
|
216 |
+
) -> str:
|
217 |
+
"""
|
218 |
+
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
|
219 |
+
Transformers module.
|
220 |
+
|
221 |
+
Args:
|
222 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
223 |
+
This can be either:
|
224 |
+
|
225 |
+
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
226 |
+
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
|
227 |
+
under a user or organization name, like `dbmdz/bert-base-german-cased`.
|
228 |
+
- a path to a *directory* containing a configuration file saved using the
|
229 |
+
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
|
230 |
+
|
231 |
+
module_file (`str`):
|
232 |
+
The name of the module file containing the class to look for.
|
233 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
234 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
|
235 |
+
cache should not be used.
|
236 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
237 |
+
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
238 |
+
exist.
|
239 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
240 |
+
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
|
241 |
+
proxies (`Dict[str, str]`, *optional*):
|
242 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
243 |
+
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
244 |
+
token (`str` or *bool*, *optional*):
|
245 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
246 |
+
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
247 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
248 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
249 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
250 |
+
identifier allowed by git.
|
251 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
252 |
+
If `True`, will only try to load the tokenizer configuration from local files.
|
253 |
+
repo_type (`str`, *optional*):
|
254 |
+
Specify the repo type (useful when downloading from a space for instance).
|
255 |
+
|
256 |
+
<Tip>
|
257 |
+
|
258 |
+
Passing `token=True` is required when you want to use a private model.
|
259 |
+
|
260 |
+
</Tip>
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
`str`: The path to the module inside the cache.
|
264 |
+
"""
|
265 |
+
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
|
266 |
+
if use_auth_token is not None:
|
267 |
+
warnings.warn(
|
268 |
+
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
269 |
+
)
|
270 |
+
if token is not None:
|
271 |
+
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
|
272 |
+
token = use_auth_token
|
273 |
+
|
274 |
+
if is_offline_mode() and not local_files_only:
|
275 |
+
logger.info("Offline mode: forcing local_files_only=True")
|
276 |
+
local_files_only = True
|
277 |
+
|
278 |
+
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
|
279 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
280 |
+
is_local = os.path.isdir(pretrained_model_name_or_path)
|
281 |
+
if is_local:
|
282 |
+
submodule = os.path.basename(pretrained_model_name_or_path)
|
283 |
+
else:
|
284 |
+
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
|
285 |
+
cached_module = try_to_load_from_cache(
|
286 |
+
pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
|
287 |
+
)
|
288 |
+
|
289 |
+
new_files = []
|
290 |
+
try:
|
291 |
+
# Load from URL or cache if already cached
|
292 |
+
resolved_module_file = cached_file(
|
293 |
+
pretrained_model_name_or_path,
|
294 |
+
module_file,
|
295 |
+
cache_dir=cache_dir,
|
296 |
+
force_download=force_download,
|
297 |
+
proxies=proxies,
|
298 |
+
resume_download=resume_download,
|
299 |
+
local_files_only=local_files_only,
|
300 |
+
token=token,
|
301 |
+
revision=revision,
|
302 |
+
repo_type=repo_type,
|
303 |
+
_commit_hash=_commit_hash,
|
304 |
+
)
|
305 |
+
if not is_local and cached_module != resolved_module_file:
|
306 |
+
new_files.append(module_file)
|
307 |
+
|
308 |
+
except EnvironmentError:
|
309 |
+
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
|
310 |
+
raise
|
311 |
+
|
312 |
+
# Check we have all the requirements in our environment
|
313 |
+
modules_needed = check_imports(resolved_module_file)
|
314 |
+
|
315 |
+
# Now we move the module inside our cached dynamic modules.
|
316 |
+
full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
|
317 |
+
create_dynamic_module(full_submodule)
|
318 |
+
submodule_path = Path(HF_MODULES_CACHE) / full_submodule
|
319 |
+
if submodule == os.path.basename(pretrained_model_name_or_path):
|
320 |
+
# We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
|
321 |
+
# has changed since last copy.
|
322 |
+
if not (submodule_path / module_file).exists() or not filecmp.cmp(
|
323 |
+
resolved_module_file, str(submodule_path / module_file)
|
324 |
+
):
|
325 |
+
shutil.copy(resolved_module_file, submodule_path / module_file)
|
326 |
+
importlib.invalidate_caches()
|
327 |
+
for module_needed in modules_needed:
|
328 |
+
module_needed = f"{module_needed}.py"
|
329 |
+
module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
|
330 |
+
if not (submodule_path / module_needed).exists() or not filecmp.cmp(
|
331 |
+
module_needed_file, str(submodule_path / module_needed)
|
332 |
+
):
|
333 |
+
shutil.copy(module_needed_file, submodule_path / module_needed)
|
334 |
+
importlib.invalidate_caches()
|
335 |
+
else:
|
336 |
+
# Get the commit hash
|
337 |
+
commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)
|
338 |
+
|
339 |
+
# The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
|
340 |
+
# benefit of versioning.
|
341 |
+
submodule_path = submodule_path / commit_hash
|
342 |
+
full_submodule = full_submodule + os.path.sep + commit_hash
|
343 |
+
create_dynamic_module(full_submodule)
|
344 |
+
|
345 |
+
if not (submodule_path / module_file).exists():
|
346 |
+
shutil.copy(resolved_module_file, submodule_path / module_file)
|
347 |
+
importlib.invalidate_caches()
|
348 |
+
# Make sure we also have every file with relative
|
349 |
+
for module_needed in modules_needed:
|
350 |
+
if not (submodule_path / f"{module_needed}.py").exists():
|
351 |
+
get_cached_module_file(
|
352 |
+
pretrained_model_name_or_path,
|
353 |
+
f"{module_needed}.py",
|
354 |
+
cache_dir=cache_dir,
|
355 |
+
force_download=force_download,
|
356 |
+
resume_download=resume_download,
|
357 |
+
proxies=proxies,
|
358 |
+
token=token,
|
359 |
+
revision=revision,
|
360 |
+
local_files_only=local_files_only,
|
361 |
+
_commit_hash=commit_hash,
|
362 |
+
)
|
363 |
+
new_files.append(f"{module_needed}.py")
|
364 |
+
|
365 |
+
if len(new_files) > 0 and revision is None:
|
366 |
+
new_files = "\n".join([f"- {f}" for f in new_files])
|
367 |
+
repo_type_str = "" if repo_type is None else f"{repo_type}s/"
|
368 |
+
url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
|
369 |
+
logger.warning(
|
370 |
+
f"A new version of the following files was downloaded from {url}:\n{new_files}"
|
371 |
+
"\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
|
372 |
+
"versions of the code file, you can pin a revision."
|
373 |
+
)
|
374 |
+
|
375 |
+
return os.path.join(full_submodule, module_file)
|
376 |
+
|
377 |
+
|
378 |
+
def get_class_from_dynamic_module(
|
379 |
+
class_reference: str,
|
380 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
381 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
382 |
+
force_download: bool = False,
|
383 |
+
resume_download: bool = False,
|
384 |
+
proxies: Optional[Dict[str, str]] = None,
|
385 |
+
token: Optional[Union[bool, str]] = None,
|
386 |
+
revision: Optional[str] = None,
|
387 |
+
local_files_only: bool = False,
|
388 |
+
repo_type: Optional[str] = None,
|
389 |
+
code_revision: Optional[str] = None,
|
390 |
+
**kwargs,
|
391 |
+
) -> typing.Type:
|
392 |
+
"""
|
393 |
+
Extracts a class from a module file, present in the local folder or repository of a model.
|
394 |
+
|
395 |
+
<Tip warning={true}>
|
396 |
+
|
397 |
+
Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
|
398 |
+
therefore only be called on trusted repos.
|
399 |
+
|
400 |
+
</Tip>
|
401 |
+
|
402 |
+
Args:
|
403 |
+
class_reference (`str`):
|
404 |
+
The full name of the class to load, including its module and optionally its repo.
|
405 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
406 |
+
This can be either:
|
407 |
+
|
408 |
+
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
409 |
+
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
|
410 |
+
under a user or organization name, like `dbmdz/bert-base-german-cased`.
|
411 |
+
- a path to a *directory* containing a configuration file saved using the
|
412 |
+
[`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
|
413 |
+
|
414 |
+
This is used when `class_reference` does not specify another repo.
|
415 |
+
module_file (`str`):
|
416 |
+
The name of the module file containing the class to look for.
|
417 |
+
class_name (`str`):
|
418 |
+
The name of the class to import in the module.
|
419 |
+
cache_dir (`str` or `os.PathLike`, *optional*):
|
420 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
|
421 |
+
cache should not be used.
|
422 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
423 |
+
Whether or not to force to (re-)download the configuration files and override the cached versions if they
|
424 |
+
exist.
|
425 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
426 |
+
Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
|
427 |
+
proxies (`Dict[str, str]`, *optional*):
|
428 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
429 |
+
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
430 |
+
token (`str` or `bool`, *optional*):
|
431 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
432 |
+
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
433 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
434 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
435 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
436 |
+
identifier allowed by git.
|
437 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
438 |
+
If `True`, will only try to load the tokenizer configuration from local files.
|
439 |
+
repo_type (`str`, *optional*):
|
440 |
+
Specify the repo type (useful when downloading from a space for instance).
|
441 |
+
code_revision (`str`, *optional*, defaults to `"main"`):
|
442 |
+
The specific revision to use for the code on the Hub, if the code leaves in a different repository than the
|
443 |
+
rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for
|
444 |
+
storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
|
445 |
+
|
446 |
+
<Tip>
|
447 |
+
|
448 |
+
Passing `token=True` is required when you want to use a private model.
|
449 |
+
|
450 |
+
</Tip>
|
451 |
+
|
452 |
+
Returns:
|
453 |
+
`typing.Type`: The class, dynamically imported from the module.
|
454 |
+
|
455 |
+
Examples:
|
456 |
+
|
457 |
+
```python
|
458 |
+
# Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
|
459 |
+
# module.
|
460 |
+
cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")
|
461 |
+
|
462 |
+
# Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
|
463 |
+
# module.
|
464 |
+
cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
|
465 |
+
```"""
|
466 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
467 |
+
if use_auth_token is not None:
|
468 |
+
warnings.warn(
|
469 |
+
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
470 |
+
)
|
471 |
+
if token is not None:
|
472 |
+
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
|
473 |
+
token = use_auth_token
|
474 |
+
|
475 |
+
# Catch the name of the repo if it's specified in `class_reference`
|
476 |
+
if "--" in class_reference:
|
477 |
+
repo_id, class_reference = class_reference.split("--")
|
478 |
+
else:
|
479 |
+
repo_id = pretrained_model_name_or_path
|
480 |
+
module_file, class_name = class_reference.split(".")
|
481 |
+
|
482 |
+
if code_revision is None and pretrained_model_name_or_path == repo_id:
|
483 |
+
code_revision = revision
|
484 |
+
# And lastly we get the class inside our newly created module
|
485 |
+
final_module = get_cached_module_file(
|
486 |
+
repo_id,
|
487 |
+
module_file + ".py",
|
488 |
+
cache_dir=cache_dir,
|
489 |
+
force_download=force_download,
|
490 |
+
resume_download=resume_download,
|
491 |
+
proxies=proxies,
|
492 |
+
token=token,
|
493 |
+
revision=code_revision,
|
494 |
+
local_files_only=local_files_only,
|
495 |
+
repo_type=repo_type,
|
496 |
+
)
|
497 |
+
return get_class_in_module(class_name, final_module.replace(".py", ""))
|
498 |
+
|
499 |
+
|
500 |
+
def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
|
501 |
+
"""
|
502 |
+
Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
|
503 |
+
adds the proper fields in a config.
|
504 |
+
|
505 |
+
Args:
|
506 |
+
obj (`Any`): The object for which to save the module files.
|
507 |
+
folder (`str` or `os.PathLike`): The folder where to save.
|
508 |
+
config (`PretrainedConfig` or dictionary, `optional`):
|
509 |
+
A config in which to register the auto_map corresponding to this custom object.
|
510 |
+
|
511 |
+
Returns:
|
512 |
+
`List[str]`: The list of files saved.
|
513 |
+
"""
|
514 |
+
if obj.__module__ == "__main__":
|
515 |
+
logger.warning(
|
516 |
+
f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
|
517 |
+
"this code in a separate module so we can include it in the saved folder and make it easier to share via "
|
518 |
+
"the Hub."
|
519 |
+
)
|
520 |
+
return
|
521 |
+
|
522 |
+
def _set_auto_map_in_config(_config):
|
523 |
+
module_name = obj.__class__.__module__
|
524 |
+
last_module = module_name.split(".")[-1]
|
525 |
+
full_name = f"{last_module}.{obj.__class__.__name__}"
|
526 |
+
# Special handling for tokenizers
|
527 |
+
if "Tokenizer" in full_name:
|
528 |
+
slow_tokenizer_class = None
|
529 |
+
fast_tokenizer_class = None
|
530 |
+
if obj.__class__.__name__.endswith("Fast"):
|
531 |
+
# Fast tokenizer: we have the fast tokenizer class and we may have the slow one has an attribute.
|
532 |
+
fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
|
533 |
+
if getattr(obj, "slow_tokenizer_class", None) is not None:
|
534 |
+
slow_tokenizer = getattr(obj, "slow_tokenizer_class")
|
535 |
+
slow_tok_module_name = slow_tokenizer.__module__
|
536 |
+
last_slow_tok_module = slow_tok_module_name.split(".")[-1]
|
537 |
+
slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
|
538 |
+
else:
|
539 |
+
# Slow tokenizer: no way to have the fast class
|
540 |
+
slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
|
541 |
+
|
542 |
+
full_name = (slow_tokenizer_class, fast_tokenizer_class)
|
543 |
+
|
544 |
+
if isinstance(_config, dict):
|
545 |
+
auto_map = _config.get("auto_map", {})
|
546 |
+
auto_map[obj._auto_class] = full_name
|
547 |
+
_config["auto_map"] = auto_map
|
548 |
+
elif getattr(_config, "auto_map", None) is not None:
|
549 |
+
_config.auto_map[obj._auto_class] = full_name
|
550 |
+
else:
|
551 |
+
_config.auto_map = {obj._auto_class: full_name}
|
552 |
+
|
553 |
+
# Add object class to the config auto_map
|
554 |
+
if isinstance(config, (list, tuple)):
|
555 |
+
for cfg in config:
|
556 |
+
_set_auto_map_in_config(cfg)
|
557 |
+
elif config is not None:
|
558 |
+
_set_auto_map_in_config(config)
|
559 |
+
|
560 |
+
result = []
|
561 |
+
# Copy module file to the output folder.
|
562 |
+
object_file = sys.modules[obj.__module__].__file__
|
563 |
+
dest_file = Path(folder) / (Path(object_file).name)
|
564 |
+
shutil.copy(object_file, dest_file)
|
565 |
+
result.append(dest_file)
|
566 |
+
|
567 |
+
# Gather all relative imports recursively and make sure they are copied as well.
|
568 |
+
for needed_file in get_relative_import_files(object_file):
|
569 |
+
dest_file = Path(folder) / (Path(needed_file).name)
|
570 |
+
shutil.copy(needed_file, dest_file)
|
571 |
+
result.append(dest_file)
|
572 |
+
|
573 |
+
return result
|
574 |
+
|
575 |
+
|
576 |
+
def _raise_timeout_error(signum, frame):
|
577 |
+
raise ValueError(
|
578 |
+
"Loading this model requires you to execute custom code contained in the model repository on your local"
|
579 |
+
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
|
580 |
+
)
|
581 |
+
|
582 |
+
|
583 |
+
TIME_OUT_REMOTE_CODE = 15
|
584 |
+
|
585 |
+
|
586 |
+
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
|
587 |
+
if trust_remote_code is None:
|
588 |
+
if has_local_code:
|
589 |
+
trust_remote_code = False
|
590 |
+
elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
|
591 |
+
try:
|
592 |
+
signal.signal(signal.SIGALRM, _raise_timeout_error)
|
593 |
+
signal.alarm(TIME_OUT_REMOTE_CODE)
|
594 |
+
while trust_remote_code is None:
|
595 |
+
answer = input(
|
596 |
+
f"The repository for {model_name} contains custom code which must be executed to correctly"
|
597 |
+
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
598 |
+
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
|
599 |
+
f"Do you wish to run the custom code? [y/N] "
|
600 |
+
)
|
601 |
+
if answer.lower() in ["yes", "y", "1"]:
|
602 |
+
trust_remote_code = True
|
603 |
+
elif answer.lower() in ["no", "n", "0", ""]:
|
604 |
+
trust_remote_code = False
|
605 |
+
signal.alarm(0)
|
606 |
+
except Exception:
|
607 |
+
# OS which does not support signal.SIGALRM
|
608 |
+
raise ValueError(
|
609 |
+
f"The repository for {model_name} contains custom code which must be executed to correctly"
|
610 |
+
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
611 |
+
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
612 |
+
)
|
613 |
+
elif has_remote_code:
|
614 |
+
# For the CI which puts the timeout at 0
|
615 |
+
_raise_timeout_error(None, None)
|
616 |
+
|
617 |
+
if has_remote_code and not has_local_code and not trust_remote_code:
|
618 |
+
raise ValueError(
|
619 |
+
f"Loading {model_name} requires you to execute the configuration file in that"
|
620 |
+
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
621 |
+
" set the option `trust_remote_code=True` to remove this error."
|
622 |
+
)
|
623 |
+
|
624 |
+
return trust_remote_code
|