pere committed on
Commit
bd3bd38
1 Parent(s): f5212b6
Files changed (3)
  1. flax_model.msgpack +0 -3
  2. run.sh +7 -8
  3. run_mlm_flax_stream.py +7 -5
flax_model.msgpack DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8a04d8032d9ff18b1b727c98ee3bea82de1d908107d257d94486dba8650f5680
-size 1113187999

run.sh CHANGED
@@ -7,17 +7,16 @@ python run_mlm_flax_stream.py \
     --weight_decay="0.01" \
     --per_device_train_batch_size="62" \
     --per_device_eval_batch_size="16" \
-    --learning_rate="4e-4" \
-    --warmup_steps="1000" \
+    --learning_rate="3e-4" \
+    --warmup_steps="25000" \
     --overwrite_output_dir \
-    --num_train_steps="10000" \
+    --num_train_steps="250000" \
     --adam_beta1="0.9" \
     --adam_beta2="0.98" \
-    --logging_steps="50" \
-    --save_steps="50" \
-    --eval_steps="50" \
+    --logging_steps="1000" \
+    --save_steps="1000" \
+    --eval_steps="1000" \
     --output_dir="./" \
     --dtype="bfloat16" \
-    --push_to_hub_organization="NbAiLab" \
-    --push_to_hub_model_id="nb-roberta-base-scandi" \
+    --hub_model_id="NbAiLab/nb-roberta-base-scandi" \
     --push_to_hub
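The retuned flags (peak learning rate 3e-4, 25000 warmup steps, 250000 total steps, checkpoints every 1000 steps instead of every 50) feed the linear warmup-plus-decay schedule that the upstream run_mlm_flax_stream.py example builds with optax. A minimal standalone sketch of that schedule, assuming this repo keeps the upstream schedule code unchanged:

import optax

# Sketch of the warmup + linear-decay learning-rate schedule derived from
# the flags above (assumption: this repo's copy matches the upstream
# HF Flax MLM example's schedule construction).
learning_rate = 3e-4       # --learning_rate
warmup_steps = 25_000      # --warmup_steps
num_train_steps = 250_000  # --num_train_steps

# Ramp linearly from 0 to the peak rate over the warmup period...
warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
)
# ...then decay linearly back to 0 over the remaining steps.
decay_fn = optax.linear_schedule(
    init_value=learning_rate,
    end_value=0.0,
    transition_steps=num_train_steps - warmup_steps,
)
schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
)

Note also that the two deprecated flags --push_to_hub_organization and --push_to_hub_model_id are replaced by the single fully qualified --hub_model_id="NbAiLab/nb-roberta-base-scandi", matching the consolidated hub arguments in recent transformers releases.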
run_mlm_flax_stream.py CHANGED
@@ -42,12 +42,14 @@ import optax
 from flax import jax_utils, traverse_util
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
+from huggingface_hub import Repository
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_MASKED_LM_MAPPING,
     AutoConfig,
     AutoTokenizer,
     FlaxAutoModelForMaskedLM,
+    AutoModelForMaskedLM,
     HfArgumentParser,
     PreTrainedTokenizerBase,
     TensorType,
@@ -650,18 +652,18 @@ if __name__ == "__main__":
                 model.save_pretrained(
                     training_args.output_dir,
                     params=params,
-                    push_to_hub=training_args.push_to_hub,
-                    commit_message=f"Saving weights and logs of step {step+1}",
                 )
-                print(f"Saving weights and logs of step {step+1}. \nThe result is saved to {training_args.output_folder} by worker {jax.process_index()}.")
+
+                repo = Repository(local_dir=training_args.output_dir)
+                repo.push_to_hub(commit_message=f"Saving weights and logs of step {step+1}", blocking=False)
+
             except:
-                breakpoint()
                 model.save_pretrained(
                     training_args.output_dir,
                     params=params
                 )
                 print("Problems pushing this to the hub. The bug should be fixed.")
+                breakpoint()
 
             # update tqdm bar
             steps.update(1)
-
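Taken together, the Python changes move the hub push out of save_pretrained (whose push_to_hub/commit_message path was dropped) and into an explicit, non-blocking huggingface_hub.Repository push, with breakpoint() relegated to the failure branch. A minimal sketch of the resulting checkpoint flow; save_and_push is a hypothetical wrapper, and it assumes output_dir is already a local clone of the target hub repo:

from huggingface_hub import Repository

def save_and_push(model, params, output_dir, step):
    # Hypothetical wrapper around the calls introduced in this commit.
    # Write the Flax checkpoint locally first.
    model.save_pretrained(output_dir, params=params)
    # Then push the local clone to the hub. blocking=False returns
    # immediately and runs the git push in the background, so the
    # training loop is not stalled by the upload at every checkpoint
    # (here, every 1000 steps of a 250000-step run).
    repo = Repository(local_dir=output_dir)
    repo.push_to_hub(
        commit_message=f"Saving weights and logs of step {step + 1}",
        blocking=False,
    )

Making the push asynchronous is the design point: with --save_steps="1000", a blocking upload of a ~1.1 GB checkpoint would otherwise sit on the critical path of every saving step.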