yhavinga committed
Commit 342d7f4
1 Parent(s): 7cc8a21

Add pytorch model at 240k steps

Files changed (2)
  1. README.md +1 -1
  2. flax_to_pytorch.py +26 -0
README.md CHANGED
@@ -14,7 +14,7 @@ datasets:
 
  Training details:
 
- * trained for 120k steps (24 dec 2021)
+ * trained for 240k steps (29 dec 2021)
  * block size: 512
  * optimizer: adam, lr 8e-4, beta1 0.9, beta2 0.98
  * warmup 5000 steps
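
The optimizer and schedule listed in the training details above could be set up roughly as follows with optax. This is a minimal sketch, assuming a JAX/Flax training stack with optax; only the peak learning rate, betas, warmup length and total step count come from the README, while the linear decay after warmup and the names peak_lr, warmup_steps, total_steps and lr_schedule are illustrative assumptions.

import optax

# Values taken from the README training details; the linear decay after warmup and
# the use of optax itself are assumptions -- only the peak LR, betas, warmup length
# and total step count are stated above.
peak_lr = 8e-4
warmup_steps = 5_000
total_steps = 240_000

lr_schedule = optax.join_schedules(
    schedules=[
        # Linear warmup from 0 to the peak learning rate over the first 5000 steps.
        optax.linear_schedule(init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps),
        # Linear decay back to 0 over the remaining steps (decay shape is an assumption).
        optax.linear_schedule(init_value=peak_lr, end_value=0.0, transition_steps=total_steps - warmup_steps),
    ],
    boundaries=[warmup_steps],
)

optimizer = optax.adam(learning_rate=lr_schedule, b1=0.9, b2=0.98)
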
flax_to_pytorch.py ADDED
@@ -0,0 +1,26 @@
+ import torch
+ import numpy as np
+ import jax
+ import jax.numpy as jnp
+ from transformers import AutoTokenizer
+ from transformers import FlaxGPT2LMHeadModel
+ from transformers import GPT2LMHeadModel
+ # Load the tokenizer and the Flax checkpoint from the current directory.
+ tokenizer = AutoTokenizer.from_pretrained(".")
+ tokenizer.pad_token = tokenizer.eos_token
+ model_fx = FlaxGPT2LMHeadModel.from_pretrained(".")
+ # Optionally cast bfloat16 params to float32 first (left disabled here).
+ # def to_f32(t):
+ #     return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
+ # model_fx.params = to_f32(model_fx.params)
+ # model_fx.save_pretrained("./fx")
+ # Convert the Flax weights to a PyTorch model and save it to ./pt.
+ model_pt = GPT2LMHeadModel.from_pretrained(".", from_flax=True)
+ model_pt.save_pretrained("./pt")
+ # Sanity check: run a dummy batch of zero tokens through both models and print the logits.
+ input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
+ input_ids_pt = torch.tensor(input_ids)
+ logits_pt = model_pt(input_ids_pt).logits
+ print(logits_pt)
+ logits_fx = model_fx(input_ids).logits
+ print(logits_fx)
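
To compare the two outputs programmatically instead of by eye, a short check like the one below could be appended to flax_to_pytorch.py. This is an optional sketch, not part of the committed script; it reuses the np, logits_pt and logits_fx names defined above, and the 1e-3 tolerance is an arbitrary assumption.

# Optional check (not in the committed script): compare PyTorch and Flax logits numerically.
max_abs_diff = np.max(np.abs(logits_pt.detach().numpy() - np.asarray(logits_fx)))
print(f"max |logits_pt - logits_fx| = {max_abs_diff}")
# 1e-3 is an arbitrary tolerance, not a value from the repository.
assert max_abs_diff < 1e-3, "Flax -> PyTorch conversion produced diverging logits"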