yhavinga committed on
Commit
1d4a13a
1 Parent(s): c7a2ca7

Saving weights and logs of step 1200

Browse files
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8e17e7235a47f50d90b1e481dcb09e9727dfe9a345b6cc36b8cc1cfd6a583c81
3
  size 891548548
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1370699db9ee8980b9d18ba78ab3c7bacbf64455af8b4d767abadaf4e1c6a466
3
  size 891548548
run_t5_mlm_flax_custom_dataset.py CHANGED
@@ -703,6 +703,13 @@ if __name__ == "__main__":
703
  else:
704
  model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
705
 
 
 
 
 
 
 
 
706
  # Data collator
707
  # This one will take care of randomly masking the tokens.
708
  data_collator = FlaxDataCollatorForT5MLM(
 
703
  else:
704
  model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
705
 
706
+
707
+ # def to_bf16(t):
708
+ # return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
709
+ #
710
+ #
711
+ # model.params = to_bf16(model.params)
712
+
713
  # Data collator
714
  # This one will take care of randomly masking the tokens.
715
  data_collator = FlaxDataCollatorForT5MLM(
runs/Jul10_08-38-10_t1v-n-0e7426e8-w-0/events.out.tfevents.1625906314.t1v-n-0e7426e8-w-0.25839.3.v2 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d1a444a03c97dd08a17796eec5b0c22674d96d4475fcb3dc8e47d8c3ec25db74
3
- size 136359
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b87fa89d0ac5eeabdea48a6a8250033be187061f5d0f1635b1d3f57ce6c7daaf
3
+ size 181839
streaming_dataset_filter_test.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Smoke-test for streaming-dataset cleaning.

Prints the first three raw examples from the streamed Dutch OSCAR corpus,
then the first three examples after mapping ``clean_text`` over the stream,
so the before/after effect of the cleaning step can be eyeballed.
"""
from itertools import islice

from clean import clean_text
from datasets import load_dataset

# Streaming mode: nothing is downloaded up front; examples arrive lazily.
dataset_v0 = load_dataset('oscar', "unshuffled_deduplicated_nl", split='train', streaming=True)


def f(obj):
    """Normalize the example's text field in place and return the example."""
    obj["text"] = clean_text(obj["text"])
    return obj


# Lazily-mapped view; dataset_v0 stays untouched for the comparison below.
dataset_v1 = dataset_v0.map(f)

# First three raw examples (was three hand-repeated print(next(it)) calls).
for example in islice(dataset_v0, 3):
    print(example)

# First three cleaned examples.
for example in islice(dataset_v1, 3):
    print(example)