Cletrason committed
Commit 7c5f10e
1 Parent(s): 6395c7e

Create trainer_seq2seq.py

Files changed (1)
  1. trainer_seq2seq.py +246 -0
trainer_seq2seq.py ADDED
@@ -0,0 +1,246 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from torch.utils.data import Dataset
+
+ from .deepspeed import is_deepspeed_zero3_enabled
+ from .trainer import Trainer
+ from .trainer_utils import PredictionOutput
+ from .utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class Seq2SeqTrainer(Trainer):
+     def evaluate(
+         self,
+         eval_dataset: Optional[Dataset] = None,
+         ignore_keys: Optional[List[str]] = None,
+         metric_key_prefix: str = "eval",
+         **gen_kwargs,
+     ) -> Dict[str, float]:
+         """
+         Run evaluation and return metrics.
+
+         The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
+         (pass it to the init `compute_metrics` argument).
+
+         You can also subclass and override this method to inject custom behavior.
+
+         Args:
+             eval_dataset (`Dataset`, *optional*):
+                 Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
+                 not accepted by the `model.forward()` method are automatically removed. It must implement the
+                 `__len__` method.
+             ignore_keys (`List[str]`, *optional*):
+                 A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                 gathering predictions.
+             metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
+                 An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
+                 "eval_bleu" if the prefix is `"eval"` (default).
+             max_length (`int`, *optional*):
+                 The maximum target length to use when predicting with the generate method.
+             num_beams (`int`, *optional*):
+                 Number of beams for beam search that will be used when predicting with the generate method. 1 means
+                 no beam search.
+             gen_kwargs:
+                 Additional `generate` specific kwargs.
+
+         Returns:
+             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
+             dictionary also contains the epoch number which comes from the training state.
+         """
+
+         # Explicitly passed generation kwargs take precedence over the defaults from the training arguments.
+         gen_kwargs = gen_kwargs.copy()
+         if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+             gen_kwargs["max_length"] = self.args.generation_max_length
+         gen_kwargs["num_beams"] = (
+             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
+         )
+         self._gen_kwargs = gen_kwargs
+
+         return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
+
+     def predict(
+         self,
+         test_dataset: Dataset,
+         ignore_keys: Optional[List[str]] = None,
+         metric_key_prefix: str = "test",
+         **gen_kwargs,
+     ) -> PredictionOutput:
+         """
+         Run prediction and return predictions and potential metrics.
+
+         Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
+         will also return metrics, like in `evaluate()`.
+
+         Args:
+             test_dataset (`Dataset`):
+                 Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
+                 `model.forward()` method are automatically removed. Has to implement the method `__len__`.
+             ignore_keys (`List[str]`, *optional*):
+                 A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                 gathering predictions.
+             metric_key_prefix (`str`, *optional*, defaults to `"test"`):
+                 An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
+                 "test_bleu" if the prefix is `"test"` (default).
+             max_length (`int`, *optional*):
+                 The maximum target length to use when predicting with the generate method.
+             num_beams (`int`, *optional*):
+                 Number of beams for beam search that will be used when predicting with the generate method. 1 means
+                 no beam search.
+             gen_kwargs:
+                 Additional `generate` specific kwargs.
+
+         <Tip>
+
+         If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
+         padding in a token classification task) the predictions will be padded (on the right) to allow for
+         concatenation into one array. The padding index is -100.
+
+         </Tip>
+
+         Returns: *NamedTuple* A namedtuple with the following keys:
+
+             - predictions (`np.ndarray`): The predictions on `test_dataset`.
+             - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
+             - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
+               labels).
+         """
+
+         # Same precedence as in `evaluate()`: explicit kwargs win over the training-argument defaults.
+         gen_kwargs = gen_kwargs.copy()
+         if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+             gen_kwargs["max_length"] = self.args.generation_max_length
+         gen_kwargs["num_beams"] = (
+             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
+         )
+         self._gen_kwargs = gen_kwargs
+
+         return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
+
+     def prediction_step(
+         self,
+         model: nn.Module,
+         inputs: Dict[str, Union[torch.Tensor, Any]],
+         prediction_loss_only: bool,
+         ignore_keys: Optional[List[str]] = None,
+     ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+         """
+         Perform an evaluation step on `model` using `inputs`.
+
+         Subclass and override to inject custom behavior.
+
+         Args:
+             model (`nn.Module`):
+                 The model to evaluate.
+             inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+                 The inputs and targets of the model.
+
+                 The dictionary will be unpacked before being fed to the model. Most models expect the targets under
+                 the argument `labels`. Check your model's documentation for all accepted arguments.
+             prediction_loss_only (`bool`):
+                 Whether or not to return the loss only.
+
+         Return:
+             Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
+             labels (each being optional).
+         """
+
+         if not self.args.predict_with_generate or prediction_loss_only:
+             return super().prediction_step(
+                 model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+             )
+
+         has_labels = "labels" in inputs
+         inputs = self._prepare_inputs(inputs)
+
+         # XXX: adapt synced_gpus for fairscale as well
+         gen_kwargs = self._gen_kwargs.copy()
+         if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
+             gen_kwargs["max_length"] = self.model.config.max_length
+         gen_kwargs["num_beams"] = (
+             gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
+         )
+         default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
+         gen_kwargs["synced_gpus"] = (
+             gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
+         )
+
+         # TODO (Joao): the following line is needed to keep a consistent result on SQUAD. Ideally, we should not
+         # block users from preparing a dataset with `decoder_input_ids`.
+         inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
+         generated_tokens = self.model.generate(**inputs, **gen_kwargs)
+
+         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation
+         # loop. TODO: remove this hack when the legacy code that initializes generation_config from a model config
+         # is removed in
+         # https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
+         if self.model.generation_config._from_model_config:
+             self.model.generation_config._from_model_config = False
+
+         # In case the batch is shorter than max length, the output should be padded.
+         if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
+             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
+         elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
+             gen_kwargs["max_new_tokens"] + 1
+         ):
+             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
+
+         with torch.no_grad():
+             if has_labels:
+                 with self.compute_loss_context_manager():
+                     outputs = model(**inputs)
+                 if self.label_smoother is not None:
+                     loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
+                 else:
+                     loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
+             else:
+                 loss = None
+
+         if self.args.prediction_loss_only:
+             return (loss, None, None)
+
+         if has_labels:
+             labels = inputs["labels"]
+             if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
+                 labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
+             elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
+                 gen_kwargs["max_new_tokens"] + 1
+             ):
+                 labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_new_tokens"] + 1)
+         else:
+             labels = None
+
+         return (loss, generated_tokens, labels)
+
+     def _pad_tensors_to_max_len(self, tensor, max_length):
+         if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
+             # If the PAD token is not defined, at least the EOS token has to be defined.
+             pad_token_id = (
+                 self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+             )
+         else:
+             if self.model.config.pad_token_id is not None:
+                 pad_token_id = self.model.config.pad_token_id
+             else:
+                 raise ValueError("pad_token_id must be set in the model's configuration in order to pad tensors")
+
+         padded_tensor = pad_token_id * torch.ones(
+             (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
+         )
+         padded_tensor[:, : tensor.shape[-1]] = tensor
+         return padded_tensor
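
A minimal usage sketch, for context: it assumes a small seq2seq checkpoint ("t5-small") and a toy in-memory dataset, none of which are part of the committed file, and shows how the `Seq2SeqTrainingArguments` generation defaults and per-call kwargs feed the `gen_kwargs` handling in `evaluate()` above.

# Hypothetical usage sketch (not part of this commit); model name and toy data are placeholders.
from datasets import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

def preprocess(batch):
    # Tokenize inputs and targets; the "labels" column is what prediction_step looks for.
    model_inputs = tokenizer(batch["src"], truncation=True)
    model_inputs["labels"] = tokenizer(text_target=batch["tgt"], truncation=True)["input_ids"]
    return model_inputs

raw = Dataset.from_dict({"src": ["translate English to German: Hello"], "tgt": ["Hallo"]})
eval_ds = raw.map(preprocess, batched=True, remove_columns=["src", "tgt"])

args = Seq2SeqTrainingArguments(
    output_dir="out",
    predict_with_generate=True,   # route evaluation through the generate() branch of prediction_step
    generation_max_length=32,     # picked up by evaluate()/predict() when no max_length/max_new_tokens is passed
    generation_num_beams=4,       # picked up when no num_beams is passed
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    eval_dataset=eval_ds,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)

# Per-call generation kwargs override the defaults configured above.
metrics = trainer.evaluate(max_length=64, num_beams=2)
print(metrics)  # without compute_metrics, this reports eval_loss, runtime, etc.

Per the code above, kwargs passed to `evaluate()`/`predict()` take precedence over `generation_max_length`/`generation_num_beams`, which in turn take precedence over the model config values that `prediction_step` falls back to.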