Merge pull request #9 from borisdayma/feat--wandb-search
Changed files:
- seq2seq/do_run.sh +1 -0
- seq2seq/run_seq2seq_flax.py +32 -12
seq2seq/do_run.sh
CHANGED
@@ -6,5 +6,6 @@ python run_seq2seq_flax.py \
     --per_device_train_batch_size 24 \
     --per_device_eval_batch_size 24 \
     --preprocessing_num_workers 48 \
+    --warmup_steps 1000 \
     --do_train \
     --do_eval \
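The new --warmup_steps 1000 flag feeds the learning rate schedule built by create_learning_rate_fn, which is referenced in the Python hunks below but whose body is outside this diff. As a point of reference only, here is a minimal sketch of the linear warmup-then-decay schedule this flag typically controls in the Hugging Face Flax examples, written with optax; the script's actual implementation may differ.

import optax

def create_learning_rate_fn(train_ds_size, train_batch_size, num_train_epochs, num_warmup_steps, learning_rate):
    """Linear warmup to `learning_rate` over `num_warmup_steps`, then linear decay to zero."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - num_warmup_steps
    )
    # The combined schedule is what gets passed to the optimizer as `linear_decay_lr_schedule_fn`.
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])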
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -215,6 +215,13 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    log_interval: Optional[int] = field(
+        default=5,
+        metadata={
+            "help": "Log training metrics every `log_interval` steps "
+            "(useful for debugging or for quicker feedback during training)."
+        },
+    )
 
     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
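The new log_interval field is consumed later in the training loop as data_args.log_interval. For orientation, a minimal, hypothetical sketch of how such a dataclass field is typically exposed as a CLI flag through transformers.HfArgumentParser; the script's actual parser setup is not part of this diff, and the trimmed dataclass below is a stand-in.

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    # Trimmed-down stand-in for the script's dataclass; only the new field is shown.
    log_interval: Optional[int] = field(
        default=5, metadata={"help": "Log training metrics every `log_interval` steps."}
    )

# Invoked as e.g.: python run_seq2seq_flax.py --log_interval 10 ...
parser = HfArgumentParser((DataTrainingArguments,))
(data_args,) = parser.parse_args_into_dataclasses()
print(data_args.log_interval)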
@@ -307,12 +314,12 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
     train_metrics = get_metrics(train_metrics)
     for key, vals in train_metrics.items():
-        tag = f"train_{key}"
+        tag = f"train_epoch/{key}"
         for i, val in enumerate(vals):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
 
     for metric_name, value in eval_metrics.items():
-        summary_writer.scalar(f"eval_{metric_name}", value, step)
+        summary_writer.scalar(f"eval/{metric_name}", value, step)
 
 
 def create_learning_rate_fn(
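The renamed tags use slash-separated names (train_epoch/..., eval/...) so that scalars are grouped into sections in the metrics dashboard rather than flattened into one long list. A minimal sketch of the convention, assuming the flax SummaryWriter the script appears to use; the log_dir and values are illustrative only.

from flax.metrics import tensorboard

# Scalars logged as "<section>/<name>" are grouped under a "<section>" panel.
summary_writer = tensorboard.SummaryWriter(log_dir="./logs")  # illustrative path
summary_writer.scalar("train_epoch/loss", 2.31, step=100)
summary_writer.scalar("eval/loss", 2.05, step=100)
summary_writer.flush()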
@@ -616,17 +623,24 @@ def main():
         return traverse_util.unflatten_dict(flat_mask)
 
     # create adam optimizer
-    adamw = optax.adamw(
-        learning_rate=linear_decay_lr_schedule_fn,
-        b1=training_args.adam_beta1,
-        b2=training_args.adam_beta2,
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        mask=decay_mask_fn,
-    )
+    if training_args.adafactor:
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=linear_decay_lr_schedule_fn,
+        )
+    else:
+        optimizer = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
+            mask=decay_mask_fn,
+        )
 
     # Setup train state
-    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
 
     # label smoothed cross entropy
     def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
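The branch above reads training_args.adafactor, whose declaration is not shown in this diff. A hypothetical sketch of how such a flag is commonly declared on the script's TrainingArguments dataclass; the field name matches what the hunk consumes, the default and help text are assumptions.

from dataclasses import dataclass, field

@dataclass
class TrainingArguments:
    # Trimmed-down stand-in; only the flag consumed by the optimizer branch is shown.
    adafactor: bool = field(
        default=False,
        metadata={"help": "Use Adafactor instead of AdamW (Adafactor keeps a smaller optimizer state, which helps on TPU)."},
    )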
@@ -718,6 +732,7 @@ def main():
 
     train_time = 0
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+    global_step = 0
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
@@ -730,11 +745,16 @@
         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
         steps_per_epoch = len(train_dataset) // train_batch_size
         # train
-        for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+            global_step += 1
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
             train_metrics.append(train_metric)
 
+            if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
+                for k, v in unreplicate(train_metric).items():
+                    wandb.log({f"train/{k}": jax.device_get(v)}, step=global_step)
+
         train_time += time.time() - train_start
 
         train_metric = unreplicate(train_metric)
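The wandb.log calls added above assume a W&B run has already been initialized elsewhere in the script; the wandb.init call itself is not part of this diff. A minimal, hypothetical sketch of the setup those calls rely on; the project name and config keys are placeholders.

import wandb

# Hypothetical initialization; project name and config values are placeholders.
wandb.init(
    project="seq2seq-flax",
    config={"per_device_train_batch_size": 24, "warmup_steps": 1000, "log_interval": 5},
)

# Later, inside the training loop (as in the hunk above), per-step metrics are sent with:
# wandb.log({"train/loss": 2.31}, step=global_step)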