feat: use custom TrainingArguments
dev/seq2seq/run_seq2seq_flax.py CHANGED (+114 -29)
@@ -44,7 +44,6 @@ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from transformers import (
     AutoTokenizer,
     HfArgumentParser,
-    TrainingArguments,
 )
 from transformers.models.bart.modeling_flax_bart import BartConfig

@@ -93,12 +92,6 @@ class ModelArguments:
             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
         },
     )
-    from_checkpoint: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Loads a pretrained wandb checkpoint. Use artifact reference."
-        },
-    )


 @dataclass
@@ -143,10 +136,6 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
-    use_decay: bool = field(
-        default=False,
-        metadata={"help": "Whether to use decay in the learning rate scheduler."},
-    )
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
@@ -173,18 +162,116 @@ class DataTrainingArguments:
             "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
         },
     )
-
-
-
+
+    def __post_init__(self):
+        if self.dataset_repo_or_path is None:
+            raise ValueError("Need a dataset repository or path.")
+
+
+@dataclass
+class TrainingArguments:
+    """
+    Arguments pertaining to training parameters.
+    """
+
+    output_dir: str = field(
+        metadata={
+            "help": "The output directory where the model predictions and checkpoints will be written."
+        },
+    )
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Overwrite the content of the output directory. "
+                "Use this to continue training if output_dir points to a checkpoint directory."
+            )
+        },
+    )
+
+    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
+    do_eval: bool = field(
+        default=False, metadata={"help": "Whether to run eval on the dev set."}
+    )
+
+    per_device_train_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
+    )
+    per_device_eval_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
+    )
+
+    gradient_accumulation_steps: int = field(
+        default=1,
+        metadata={
+            "help": "Number of updates steps to accumulate before performing a backward/update pass."
+        },
+    )
+
+    learning_rate: float = field(
+        default=5e-5, metadata={"help": "The initial learning rate."}
+    )
+    adafactor: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to replace AdamW by Adafactor."},
+    )
+    weight_decay: float = field(
+        default=None, metadata={"help": "Weight decay if we apply some."}
+    )
+    adam_beta1: float = field(
+        default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}
+    )
+    adam_beta2: float = field(
+        default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}
+    )
+    adam_epsilon: float = field(
+        default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
+    )
+    max_grad_norm: float = field(
+        default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
+    )
+    use_decay: bool = field(
+        default=False,
+        metadata={"help": "Whether to use decay in the learning rate scheduler."},
+    )
+
+    num_train_epochs: float = field(
+        default=3.0, metadata={"help": "Total number of training epochs to perform."}
+    )
+    warmup_steps: int = field(
+        default=0, metadata={"help": "Linear warmup over warmup_steps."}
+    )
+
+    logging_steps: int = field(
+        default=40, metadata={"help": "Log every X updates steps."}
+    )
+    eval_steps: int = field(
+        default=400, metadata={"help": "Run an evaluation every X steps."}
+    )
+    save_steps: int = field(
+        default=4000, metadata={"help": "Save checkpoint every X updates steps."}
+    )
     )
     log_model: bool = field(
         default=False,
-        metadata={"help": "Log
+        metadata={"help": "Log model to wandb at `save_steps` frequency."},
     )

-
-
-
+    seed: int = field(
+        default=42,
+        metadata={"help": "Random seed that will be set at the beginning of training."},
+    )
+
+    push_to_hub: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to upload the trained model to the model hub after training."
+        },
+    )
+
+    resume_from_wandb_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "The reference to a wandb artifact for resuming training."},
+    )


 class TrainState(train_state.TrainState):
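Downstream, this new dataclass is handed to HfArgumentParser together with ModelArguments and DataTrainingArguments (see the hunk below). A minimal, self-contained sketch of that mechanism, using a tiny stand-in dataclass rather than the script's full TrainingArguments:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class TrainingArguments:
    # Illustrative stand-in with two of the fields defined above.
    output_dir: str = field(metadata={"help": "Where checkpoints are written."})
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate."})


parser = HfArgumentParser((TrainingArguments,))
# Each dataclass field becomes a --flag; types and defaults come from the field definitions.
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "output", "--learning_rate", "1e-4"]
)
print(training_args.learning_rate)  # 0.0001
```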
@@ -291,10 +378,7 @@ def wandb_log(metrics, step=None, prefix=None):


 def main():
-    # See all possible arguments
-    # or by passing the --help flag to this script.
-    # We now keep distinct sets of args, for a cleaner separation of concerns.
-
+    # See all possible arguments by passing the --help flag to this script.
     parser = HfArgumentParser(
         (ModelArguments, DataTrainingArguments, TrainingArguments)
     )
@@ -358,8 +442,8 @@ def main():
         config=parser.parse_args(),
     )

-    if model_args.from_checkpoint is not None:
-        artifact = wandb.run.use_artifact(model_args.from_checkpoint)
+    if training_args.resume_from_wandb_checkpoint is not None:
+        artifact = wandb.run.use_artifact(training_args.resume_from_wandb_checkpoint)
         artifact_dir = artifact.download()

     # load model
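Resuming now hinges on a wandb artifact reference instead of a model-level flag. A rough sketch of the expected usage, with placeholder entity, project, and artifact names:

```python
import wandb

# Placeholder names; the real reference points at a checkpoint artifact logged by a previous run.
# Artifact references look like "<entity>/<project>/<artifact_name>:<version-or-alias>".
run = wandb.init(project="dalle-mini", entity="my-entity")
artifact = run.use_artifact("my-entity/dalle-mini/model-run123:latest")
artifact_dir = artifact.download()  # local directory containing the saved state
```

On the command line this presumably maps to `--resume_from_wandb_checkpoint my-entity/dalle-mini/model-run123:latest`.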
@@ -574,7 +658,7 @@ def main():
     learning_rate_fn = create_learning_rate_fn(
         training_args.warmup_steps,
         training_args.learning_rate,
-        data_args.use_decay,
+        training_args.use_decay,
         num_train_steps,
     )

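The body of create_learning_rate_fn is not part of this diff; only the `use_decay` argument now comes from TrainingArguments. As a sketch only, here is what a warmup-then-constant or warmup-then-linear-decay schedule with that signature could look like in optax (the script's actual implementation lives elsewhere in the file and may differ):

```python
import optax


def create_learning_rate_fn(warmup_steps, learning_rate, use_decay, num_train_steps):
    # Linear warmup from 0 up to the peak learning rate.
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
    )
    if not use_decay:
        # Constant learning rate after warmup.
        return optax.join_schedules(
            schedules=[warmup_fn, optax.constant_schedule(learning_rate)],
            boundaries=[warmup_steps],
        )
    # Linear decay back to 0 over the remaining steps.
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0.0,
        transition_steps=num_train_steps - warmup_steps,
    )
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
```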
@@ -607,6 +691,7 @@ def main():
             learning_rate=learning_rate_fn,
             weight_decay_rate=training_args.weight_decay,
             weight_decay_mask=decay_mask_fn,
+            clipping_threshold=training_args.max_grad_norm,
         )
     else:
         optimizer = optax.adamw(
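optax.adafactor exposes a clipping_threshold argument (Adafactor's update clipping), so the new --max_grad_norm field is wired into it here. A rough sketch of the branch, with placeholder values standing in for the script's actual configuration:

```python
import optax

# Placeholder values; the script passes learning_rate_fn, training_args.weight_decay,
# and training_args.max_grad_norm instead.
optimizer = optax.adafactor(
    learning_rate=5e-5,
    weight_decay_rate=0.0,
    clipping_threshold=1.0,
)
```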
@@ -631,7 +716,7 @@ def main():
         tx=optimizer,
         dropout_rng=dropout_rng,
     )
-    if model_args.from_checkpoint is not None:
+    if training_args.resume_from_wandb_checkpoint is not None:
         # restore optimizer state and other parameters
         # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
         state = state.restore_state(artifact_dir)
@@ -771,7 +856,7 @@ def main():
         with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
             f.write(to_bytes(opt_state))
         state_dict = {
-            k: unreplicate(getattr(state, k))
+            k: jax.device_get(unreplicate(getattr(state, k))).item()
             for k in ["step", "epoch", "train_time", "train_samples"]
         }
         with (Path(training_args.output_dir) / "training_state.json").open(
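The added jax.device_get(...).item() matters because the replicated state fields are device arrays, which json.dump cannot serialize; .item() converts them to plain Python numbers before they are written to training_state.json. A small self-contained illustration:

```python
import json

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate, unreplicate

# Mimic a replicated scalar such as state.step.
replicated_step = replicate(jnp.array(100))

step = jax.device_get(unreplicate(replicated_step)).item()
print(json.dumps({"step": step}))  # {"step": 100}
# json.dumps({"step": unreplicate(replicated_step)}) would raise TypeError.
```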
@@ -783,7 +868,7 @@ def main():
             )

         # save to W&B
-        if data_args.log_model:
+        if training_args.log_model:
             # save some space
             c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
             c.cleanup(wandb.util.from_human_size("10GB"))
@@ -866,7 +951,7 @@ def main():
                 )
                 step = unreplicate(state.step)

-                if step %
+                if step % training_args.logging_steps == 0 and jax.process_index() == 0:
                     # log metrics
                     wandb_log(unreplicate(train_metric), step=step, prefix="train")
                     # log state parameters