versae commited on
Commit
9c5541b
1 Parent(s): 36b7dde

Adding base config and organizing configs

Browse files
configs/base/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
config.json → configs/large/config.json RENAMED
File without changes
run.sh CHANGED
@@ -1,9 +1,9 @@
1
  # From https://arxiv.org/pdf/1907.11692.pdf
2
- HUB_TOKEN=`cat $HOME/.huggingface/token`
3
  ./run_mlm_flax.py \
4
- --output_dir="./" \
5
  --model_type="roberta" \
6
- --config_name="./" \
7
  --tokenizer_name="./" \
8
  --dataset_name="mc4" \
9
  --dataset_config_name="es" \
@@ -25,7 +25,4 @@ HUB_TOKEN=`cat $HOME/.huggingface/token`
25
  --num_train_steps="500000" \
26
  --eval_steps="10000" \
27
  --logging_steps="500" \
28
- --dtype="bfloat16" \
29
- --push_to_hub_model_id="flax-community/bertin-roberta-large-spanish" \
30
- --push_to_hub_token="$HUB_TOKEN"
31
- --push_to_hub 2>&1 | tee run.log
1
  # From https://arxiv.org/pdf/1907.11692.pdf
2
+ python -c "import jax; print('TPUs', jax.device_count())"
3
  ./run_mlm_flax.py \
4
+ --output_dir="./outputs" \
5
  --model_type="roberta" \
6
+ --config_name="./configs/large" \
7
  --tokenizer_name="./" \
8
  --dataset_name="mc4" \
9
  --dataset_config_name="es" \
25
  --num_train_steps="500000" \
26
  --eval_steps="10000" \
27
  --logging_steps="500" \
28
+ --dtype="bfloat16" 2>&1 | tee run.log
 
 
 
run_stream.sh CHANGED
@@ -3,7 +3,7 @@ python -c "import jax; print('TPUs', jax.device_count())"
3
  python ./run_mlm_flax_stream.py \
4
  --output_dir="./outputs" \
5
  --model_type="roberta" \
6
- --config_name="./config-base.json" \
7
  --tokenizer_name="./" \
8
  --dataset_name="./mc4" \
9
  --dataset_config_name="es" \
3
  python ./run_mlm_flax_stream.py \
4
  --output_dir="./outputs" \
5
  --model_type="roberta" \
6
+ --config_name="./configs/base" \
7
  --tokenizer_name="./" \
8
  --dataset_name="./mc4" \
9
  --dataset_config_name="es" \