pere commited on
Commit
37ada0d
1 Parent(s): 2d0cbe0

after restart

Browse files
events.out.tfevents.1628300146.t1v-n-1a0a7c50-w-0.1616525.3.v2 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:94e7c67609c3f27b9c3650ba91ed66f4f1a3106cf7c7c92c998bc28b399653d3
3
- size 52176063
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc846e02789f45a7b1ada21713b56d4496b6dfa15c056619f2a5b46eb22ecdc0
3
+ size 52250633
flax_model_backup350k.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7a1b3c986f8f90bfb67978326d950efc15e64ad9b6d5f4884bcde1013a65968
3
+ size 1100762015
run_streaming.sh CHANGED
@@ -3,6 +3,7 @@
3
  --model_type="t5" \
4
  --config_name="./" \
5
  --tokenizer_name="./" \
 
6
  --dataset_name="pere/norwegian_colossal_corpus_v2_short100k" \
7
  --max_seq_length="512" \
8
  --weight_decay="0.01" \
3
  --model_type="t5" \
4
  --config_name="./" \
5
  --tokenizer_name="./" \
6
+ --model_name_or_path="./" \
7
  --dataset_name="pere/norwegian_colossal_corpus_v2_short100k" \
8
  --max_seq_length="512" \
9
  --weight_decay="0.01" \
run_t5_mlm_flax_streaming.py CHANGED
@@ -552,16 +552,17 @@ if __name__ == "__main__":
552
  rng = jax.random.PRNGKey(training_args.seed)
553
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
554
 
555
- model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
556
-
557
- #if model_args.model_name_or_path:
558
- # model = FlaxT5ForConditionalGeneration.from_pretrained(
559
- # model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
560
- # )
561
- #else:
562
- # model = FlaxT5ForConditionalGeneration.from_pretrained(
563
- # config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
564
- # )
 
565
 
566
 
567
  # Data collator
552
  rng = jax.random.PRNGKey(training_args.seed)
553
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
554
 
555
+ #Pere changed 13 august
556
+ #model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
557
+
558
+ if model_args.model_name_or_path:
559
+ model = FlaxT5ForConditionalGeneration.from_pretrained(
560
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
561
+ )
562
+ else:
563
+ model = FlaxT5ForConditionalGeneration.from_pretrained(
564
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
565
+ )
566
 
567
 
568
  # Data collator