Create trainer_seq2seq.py
Browse files- trainer_seq2seq.py +246 -0
trainer_seq2seq.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
from torch.utils.data import Dataset
|
20 |
+
|
21 |
+
from .deepspeed import is_deepspeed_zero3_enabled
|
22 |
+
from .trainer import Trainer
|
23 |
+
from .trainer_utils import PredictionOutput
|
24 |
+
from .utils import logging
|
25 |
+
|
26 |
+
|
27 |
+
# Module-level logger, created through the package's own `logging` helper (imported from `.utils`).
logger = logging.get_logger(__name__)
|
28 |
+
|
29 |
+
|
30 |
+
class Seq2SeqTrainer(Trainer):
    """
    [`Trainer`] subclass whose evaluation/prediction step can decode with `self.model.generate`.

    `evaluate` and `predict` accept extra `generate` keyword arguments, fill in defaults from `self.args`
    (`generation_max_length`, `generation_num_beams`) and stash them on `self._gen_kwargs`. `prediction_step` then
    uses them when `self.args.predict_with_generate` is set; otherwise it defers to the parent implementation.
    """

    def _set_gen_kwargs(self, gen_kwargs: Dict[str, Any]) -> None:
        """
        Store a copy of `gen_kwargs` on `self._gen_kwargs`, with defaults taken from `self.args`.

        `max_length` falls back to `self.args.generation_max_length` only when neither `max_length` nor
        `max_new_tokens` was given; `num_beams` falls back to `self.args.generation_num_beams`.
        """
        # Copy so the caller's dictionary is never mutated.
        gen_kwargs = gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.args.generation_max_length
        if gen_kwargs.get("num_beams") is None:
            gen_kwargs["num_beams"] = self.args.generation_num_beams
        self._gen_kwargs = gen_kwargs

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        **gen_kwargs,
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is `"eval"` (default)
            max_length (`int`, *optional*):
                The maximum target length to use when predicting with the generate method.
            num_beams (`int`, *optional*):
                Number of beams for beam search that will be used when predicting with the generate method. 1 means no
                beam search.
            gen_kwargs:
                Additional `generate` specific kwargs.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """

        self._set_gen_kwargs(gen_kwargs)

        return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def predict(
        self,
        test_dataset: Dataset,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "test",
        **gen_kwargs,
    ) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "test_bleu" if the prefix is `"test"` (default)
            max_length (`int`, *optional*):
                The maximum target length to use when predicting with the generate method.
            num_beams (`int`, *optional*):
                Number of beams for beam search that will be used when predicting with the generate method. 1 means no
                beam search.
            gen_kwargs:
                Additional `generate` specific kwargs.

        <Tip>

        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
        padding in a token classification task) the predictions will be padded (on the right) to allow for
        concatenation into one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """

        self._set_gen_kwargs(gen_kwargs)

        return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                Forwarded to the parent `prediction_step` when generation is not used; unused on the generation path.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, the
            generated tokens and the labels (each being optional).
        """

        # Without generation (or when only the loss is wanted) the parent implementation is sufficient.
        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        # XXX: adapt synced_gpus for fairscale as well
        # `_gen_kwargs` is only set by `evaluate`/`predict`; fall back to an empty dict when this step is reached
        # without going through them, so that the model-config defaults below apply.
        gen_kwargs = getattr(self, "_gen_kwargs", {}).copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.model.config.max_length
        if gen_kwargs.get("num_beams") is None:
            gen_kwargs["num_beams"] = self.model.config.num_beams
        if gen_kwargs.get("synced_gpus") is None:
            gen_kwargs["synced_gpus"] = bool(is_deepspeed_zero3_enabled())

        # TODO (Joao): the following line is needed to keep a consistent result on SQUAD. Ideally, we should not block
        # users from preparing a dataset with `decoder_input_ids`.
        inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
        generated_tokens = self.model.generate(**inputs, **gen_kwargs)

        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
        # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
        # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
        if self.model.generation_config._from_model_config:
            self.model.generation_config._from_model_config = False

        # in case the batch is shorter than max length, the output should be padded
        generated_tokens = self._pad_to_generation_length(generated_tokens, gen_kwargs)

        with torch.no_grad():
            if has_labels:
                with self.compute_loss_context_manager():
                    outputs = model(**inputs)
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        # The `prediction_loss_only` argument is falsy at this point (see the early return above); this checks the
        # flag on the training arguments instead.
        if self.args.prediction_loss_only:
            return (loss, None, None)

        if has_labels:
            # Labels go through the same padding helper as the generated tokens.
            labels = self._pad_to_generation_length(inputs["labels"], gen_kwargs)
        else:
            labels = None

        return (loss, generated_tokens, labels)

    def _pad_to_generation_length(self, tensor, gen_kwargs):
        """
        Right-pad `tensor` to the length implied by `gen_kwargs`, if it is shorter.

        The target is `max_length` when set and longer than the tensor, otherwise `max_new_tokens + 1` when set and
        longer than the tensor. The tensor is returned unchanged when neither applies.
        """
        max_length = gen_kwargs.get("max_length")
        max_new_tokens = gen_kwargs.get("max_new_tokens")
        if max_length is not None and tensor.shape[-1] < max_length:
            return self._pad_tensors_to_max_len(tensor, max_length)
        if max_new_tokens is not None and tensor.shape[-1] < max_new_tokens + 1:
            return self._pad_tensors_to_max_len(tensor, max_new_tokens + 1)
        return tensor

    def _pad_tensors_to_max_len(self, tensor, max_length):
        """
        Return a `(tensor.shape[0], max_length)` tensor holding `tensor` on the left and pad ids on the right.

        The pad id is looked up, in order, from `self.tokenizer.pad_token_id`, `self.tokenizer.eos_token_id` and
        `self.model.config.pad_token_id`.

        Raises:
            ValueError: If none of those sources provides a pad id.
        """
        pad_token_id = None
        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
            # If PAD token is not defined at least EOS token has to be defined
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = getattr(self.tokenizer, "eos_token_id", None)

        # Last resort: the model configuration (also covers a tokenizer with neither PAD nor EOS).
        if pad_token_id is None:
            pad_token_id = self.model.config.pad_token_id
        if pad_token_id is None:
            raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")

        padded_tensor = torch.full(
            (tensor.shape[0], max_length), pad_token_id, dtype=tensor.dtype, device=tensor.device
        )
        padded_tensor[:, : tensor.shape[-1]] = tensor
        return padded_tensor
|