pere commited on
Commit
221024e
1 Parent(s): d20a6e0

new attempt

Browse files
Files changed (2) hide show
  1. run.sh +2 -1
  2. run_mlm_flax_stream.py +7 -2
run.sh CHANGED
@@ -1,7 +1,6 @@
1
 
2
  python run_mlm_flax_stream.py \
3
  --output_dir="./" \
4
- --hub_model_id="NbAiLab/nb-roberta-base-scandi" \
5
  --model_name_or_path="xlm-roberta-base" \
6
  --config_name="./config.json" \
7
  --tokenizer_name="./" \
@@ -13,6 +12,7 @@ python run_mlm_flax_stream.py \
13
  --learning_rate="4e-4" \
14
  --warmup_steps="1000" \
15
  --overwrite_output_dir \
 
16
  --num_train_steps="10000" \
17
  --adam_beta1="0.9" \
18
  --adam_beta2="0.98" \
@@ -20,4 +20,5 @@ python run_mlm_flax_stream.py \
20
  --save_steps="50" \
21
  --eval_steps="50" \
22
  --dtype="bfloat16" \
 
23
  --push_to_hub
 
1
 
2
  python run_mlm_flax_stream.py \
3
  --output_dir="./" \
 
4
  --model_name_or_path="xlm-roberta-base" \
5
  --config_name="./config.json" \
6
  --tokenizer_name="./" \
 
12
  --learning_rate="4e-4" \
13
  --warmup_steps="1000" \
14
  --overwrite_output_dir \
15
+ --use_auth_token \
16
  --num_train_steps="10000" \
17
  --adam_beta1="0.9" \
18
  --adam_beta2="0.98" \
 
20
  --save_steps="50" \
21
  --eval_steps="50" \
22
  --dtype="bfloat16" \
23
+ --push_to_hub_model_id="NbAiLab/nb-roberta-base-scandi" \
24
  --push_to_hub
run_mlm_flax_stream.py CHANGED
@@ -655,13 +655,18 @@ if __name__ == "__main__":
655
  )
656
  print(f"Saving weights and logs of step {step+1}. \nThe result is saved to {training_args.output_folder} by worker {jax.process_index()}.")
657
  except:
658
-
659
  model.save_pretrained(
660
  training_args.output_dir,
661
  params=params
662
  )
663
  print("Problems pushing this to the hub. The bug should be fixed.")
664
-
 
 
 
 
 
665
  # update tqdm bar
666
  steps.update(1)
667
 
 
655
  )
656
  print(f"Saving weights and logs of step {step+1}. \nThe result is saved to {training_args.output_folder} by worker {jax.process_index()}.")
657
  except:
658
+ breakpoint()
659
  model.save_pretrained(
660
  training_args.output_dir,
661
  params=params
662
  )
663
  print("Problems pushing this to the hub. The bug should be fixed.")
664
+ else:
665
+ model.save_pretrained(
666
+ training_args.output_dir,
667
+ params=params
668
+ )
669
+
670
  # update tqdm bar
671
  steps.update(1)
672