program: run_seq2seq_flax.py
entity: wandb
project: hf-flax-dalle-mini
method: random
metric:
  name: eval/loss
  goal: minimize
parameters:
  learning_rate:
    distribution: log_uniform  # samples from exp(min) to exp(max), i.e. 5e-5 to 5e-3 on a log scale
    min: -9.9
    max: -5.3
  gradient_accumulation_steps:
    value: 8
  warmup_steps:  # in terms of optimization steps, so multiplied by gradient accumulation
    value: 125
command:
  - python3
  - ${program}
  - "--train_file"
  - "/data/CC12M/encoded-small-train.tsv"
  - "--validation_file"
  - "/data/CC12M/encoded-small-valid.tsv"
  - "--output_dir"
  - "./output_sweep"
  - "--overwrite_output_dir"
  - "--adafactor"
  - "--num_train_epochs"
  - 1
  - "--max_train_samples"
  - 1500000
  - "--per_device_train_batch_size"
  - 56
  - "--per_device_eval_batch_size"
  - 56
  - "--preprocessing_num_workers"
  - 80
  - "--no_decay"
  - "--do_train"
  - "--do_eval"
  - ${args}
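
# How to launch this sweep (a minimal sketch, assuming this file is saved as
# sweep.yaml and you are already authenticated via `wandb login`; <sweep_id>
# is a placeholder for the ID printed by the first command):
#
#   wandb sweep sweep.yaml
#   wandb agent wandb/hf-flax-dalle-mini/<sweep_id>
#
# Each agent run substitutes the sampled hyperparameters into ${args} and
# executes the `command` above.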