cahya commited on
Commit
2fe1133
1 Parent(s): 4f4b312

update jax converter

Browse files
Files changed (2) hide show
  1. jax2torch.py +6 -2
  2. run_pretraining.sh +1 -0
jax2torch.py CHANGED
@@ -1,8 +1,12 @@
1
- from transformers import GPT2Config, GPT2LMHeadModel
2
 
3
  '''
4
- This is a script to convert the Jax model to Pytorch model
5
  '''
6
 
7
  model = GPT2LMHeadModel.from_pretrained(".", from_flax=True)
8
  model.save_pretrained(".")
 
 
 
 
 
1
+ from transformers import AutoTokenizer, GPT2LMHeadModel
2
 
3
  '''
4
+ This is a script to convert the Jax model and the tokenizer to Pytorch model
5
  '''
6
 
7
  model = GPT2LMHeadModel.from_pretrained(".", from_flax=True)
8
  model.save_pretrained(".")
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained(".")
11
+ tokenizer.save_pretrained(".")
12
+
run_pretraining.sh CHANGED
@@ -4,6 +4,7 @@ export WANDB_PROJECT="hf-flax-gpt2-indonesian"
4
  export WANDB_LOG_MODEL="true"
5
 
6
  ./run_clm_flax.py \
 
7
  --output_dir="${MODEL_DIR}" \
8
  --model_type="gpt2" \
9
  --config_name="${MODEL_DIR}" \
 
4
  export WANDB_LOG_MODEL="true"
5
 
6
  ./run_clm_flax.py \
7
+ --model_name_or_path="flax_model.msgpack" \
8
  --output_dir="${MODEL_DIR}" \
9
  --model_type="gpt2" \
10
  --config_name="${MODEL_DIR}" \