boris commited on
Commit
074c5e1
1 Parent(s): 9ed6378

feat: log epoch + check params

Browse files
dev/seq2seq/do_big_run.sh CHANGED
@@ -1,16 +1,16 @@
1
  python run_seq2seq_flax.py \
2
- --max_source_length 128 \
3
  --dataset_repo_or_path dalle-mini/encoded \
4
  --train_file **/train/*/*.jsonl \
5
  --validation_file **/valid/*/*.jsonl \
 
 
6
  --streaming \
7
- --len_train 1000000 \
8
- --len_eval 100 \
9
  --output_dir output \
10
  --per_device_train_batch_size 56 \
11
  --per_device_eval_batch_size 56 \
12
  --preprocessing_num_workers 80 \
13
- --warmup_steps 250 \
14
  --gradient_accumulation_steps 8 \
15
  --do_train \
16
  --do_eval \
 
1
  python run_seq2seq_flax.py \
 
2
  --dataset_repo_or_path dalle-mini/encoded \
3
  --train_file **/train/*/*.jsonl \
4
  --validation_file **/valid/*/*.jsonl \
5
+ --len_train 42684248 \
6
+ --len_eval 34328 \
7
  --streaming \
8
+ --normalize_text \
 
9
  --output_dir output \
10
  --per_device_train_batch_size 56 \
11
  --per_device_eval_batch_size 56 \
12
  --preprocessing_num_workers 80 \
13
+ --warmup_steps 500 \
14
  --gradient_accumulation_steps 8 \
15
  --do_train \
16
  --do_eval \
dev/seq2seq/do_small_run.sh CHANGED
@@ -2,9 +2,9 @@ python run_seq2seq_flax.py \
2
  --dataset_repo_or_path dalle-mini/encoded \
3
  --train_file **/train/*/*.jsonl \
4
  --validation_file **/valid/*/*.jsonl \
 
 
5
  --streaming \
6
- --len_train 1000000 \
7
- --len_eval 1000 \
8
  --output_dir output \
9
  --per_device_train_batch_size 56 \
10
  --per_device_eval_batch_size 56 \
@@ -15,5 +15,5 @@ python run_seq2seq_flax.py \
15
  --do_eval \
16
  --adafactor \
17
  --num_train_epochs 1 \
18
- --max_train_samples 20000 \
19
  --learning_rate 0.005
 
2
  --dataset_repo_or_path dalle-mini/encoded \
3
  --train_file **/train/*/*.jsonl \
4
  --validation_file **/valid/*/*.jsonl \
5
+ --len_train 42684248 \
6
+ --len_eval 34328 \
7
  --streaming \
 
 
8
  --output_dir output \
9
  --per_device_train_batch_size 56 \
10
  --per_device_eval_batch_size 56 \
 
15
  --do_eval \
16
  --adafactor \
17
  --num_train_epochs 1 \
18
+ --max_train_samples 10000 \
19
  --learning_rate 0.005
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -138,16 +138,6 @@ class DataTrainingArguments:
138
  Arguments pertaining to what data we are going to input our model for training and eval.
139
  """
140
 
141
- dataset_name: Optional[str] = field(
142
- default=None,
143
- metadata={"help": "The name of the dataset to use (via the datasets library)."},
144
- )
145
- dataset_config_name: Optional[str] = field(
146
- default=None,
147
- metadata={
148
- "help": "The configuration name of the dataset to use (via the datasets library)."
149
- },
150
- )
151
  text_column: Optional[str] = field(
152
  default="caption",
153
  metadata={
@@ -260,14 +250,10 @@ class DataTrainingArguments:
260
  )
261
 
262
  def __post_init__(self):
263
- if (
264
- self.dataset_name is None
265
- and self.train_file is None
266
- and self.validation_file is None
267
- ):
268
- raise ValueError(
269
- "Need either a dataset name or a training/validation file."
270
- )
271
  else:
272
  if self.train_file is not None:
273
  extension = self.train_file.split(".")[-1]
@@ -287,6 +273,10 @@ class DataTrainingArguments:
287
  ], "`validation_file` should be a tsv, csv or json file."
288
  if self.val_max_target_length is None:
289
  self.val_max_target_length = self.max_target_length
 
 
 
 
290
 
291
 
292
  class TrainState(train_state.TrainState):
@@ -467,18 +457,6 @@ def main():
467
  "Use --overwrite_output_dir to overcome."
468
  )
469
 
470
- # Set up wandb run
471
- wandb.init(
472
- entity="dalle-mini",
473
- project="dalle-mini",
474
- job_type="Seq2Seq",
475
- config=parser.parse_args(),
476
- )
477
-
478
- # set default x-axis as 'train/step'
479
- wandb.define_metric("train/step")
480
- wandb.define_metric("*", step_metric="train/step")
481
-
482
  # Make one log on every process with the configuration for debugging.
483
  pylogging.basicConfig(
484
  format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -528,6 +506,18 @@ def main():
528
 
529
  return step, optimizer_step, opt_state
530
 
 
 
 
 
 
 
 
 
 
 
 
 
531
  if model_args.from_checkpoint is not None:
532
  artifact = wandb.run.use_artifact(model_args.from_checkpoint)
533
  artifact_dir = artifact.download()
@@ -1006,6 +996,7 @@ def main():
1006
 
1007
  for epoch in epochs:
1008
  # ======================== Training ================================
 
1009
 
1010
  # Create sampling rng
1011
  rng, input_rng = jax.random.split(rng)
 
138
  Arguments pertaining to what data we are going to input our model for training and eval.
139
  """
140
 
 
 
 
 
 
 
 
 
 
 
141
  text_column: Optional[str] = field(
142
  default="caption",
143
  metadata={
 
250
  )
251
 
252
  def __post_init__(self):
253
+ if self.dataset_repo_or_path is None:
254
+ raise ValueError("Need a dataset repository or path.")
255
+ if self.train_file is None or self.validation_file is None:
256
+ raise ValueError("Need training/validation file.")
 
 
 
 
257
  else:
258
  if self.train_file is not None:
259
  extension = self.train_file.split(".")[-1]
 
273
  ], "`validation_file` should be a tsv, csv or json file."
274
  if self.val_max_target_length is None:
275
  self.val_max_target_length = self.max_target_length
276
+ if self.streaming and (self.len_train is None or self.len_eval is None):
277
+ raise ValueError(
278
+ "Streaming requires providing length of training and validation datasets"
279
+ )
280
 
281
 
282
  class TrainState(train_state.TrainState):
 
457
  "Use --overwrite_output_dir to overcome."
458
  )
459
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  # Make one log on every process with the configuration for debugging.
461
  pylogging.basicConfig(
462
  format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
 
506
 
507
  return step, optimizer_step, opt_state
508
 
509
+ # Set up wandb run
510
+ wandb.init(
511
+ entity="dalle-mini",
512
+ project="dalle-mini",
513
+ job_type="Seq2Seq",
514
+ config=parser.parse_args(),
515
+ )
516
+
517
+ # set default x-axis as 'train/step'
518
+ wandb.define_metric("train/step")
519
+ wandb.define_metric("*", step_metric="train/step")
520
+
521
  if model_args.from_checkpoint is not None:
522
  artifact = wandb.run.use_artifact(model_args.from_checkpoint)
523
  artifact_dir = artifact.download()
 
996
 
997
  for epoch in epochs:
998
  # ======================== Training ================================
999
+ wandb_log({"train/epoch": epoch}, step=global_step)
1000
 
1001
  # Create sampling rng
1002
  rng, input_rng = jax.random.split(rng)