Rolv-Arild commited on
Commit
a5d5245
1 Parent(s): e71c2b1

Saving weights and logs of step 35000

Browse files
convert.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import tempfile
3
+
4
+ import jax
5
+ from jax import numpy as jnp
6
+ from transformers import AutoTokenizer, FlaxBertForMaskedLM, BertForMaskedLM
7
+
8
+
9
+ def to_f32(t):
10
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
11
+
12
+
13
+ def main():
14
+ # Saving extra files from config.json and tokenizer.json files
15
+ tokenizer = AutoTokenizer.from_pretrained("./")
16
+ tokenizer.save_pretrained("./")
17
+
18
+ # Temporary saving bfloat16 Flax model into float32
19
+ tmp = tempfile.mkdtemp()
20
+ flax_model = FlaxBertForMaskedLM.from_pretrained("./")
21
+ flax_model.params = to_f32(flax_model.params)
22
+ flax_model.save_pretrained(tmp)
23
+ # Converting float32 Flax to PyTorch
24
+ model = BertForMaskedLM.from_pretrained(tmp, from_flax=True)
25
+ model.save_pretrained("./", save_config=False)
26
+
27
+
28
+ if __name__ == "__main__":
29
+ main()
events.out.tfevents.1649146360.t1v-n-eedfb410-w-0.232238.0.v2 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dd9a16b3f835bda324ec54c6d055ddbef0ff69304226141f833adf5e119b9ee8
3
- size 5061593
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b0b0a97b939b6c937442a77a2798640d608b058f3eca8584ff7cff0ddc53098d
3
+ size 5211435
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9e221febfb233514ab4836cc21f9905fffada2463fa5f5db7562aa39eec7e96a
3
  size 711905363
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca4841024e3ac944e0ca9eb204cea434ed7ec5ccda7f8b9c8028ad1ba103acd7
3
  size 711905363