aapot commited on
Commit
e52e3c0
·
1 Parent(s): 60b6bc0

Add 128 pytorch model

Browse files
events.out.tfevents.1639865567.t1v-n-8eba1090-w-0.1317510.0.v2 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ce839b5e41f8765fb370f9d951ce8337ae17c2685c247f1273d66058f88f093b
3
- size 40215815
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1ab87187bbc8b9b5e02916e5b331c99f73c0112cf4a39c025c8d103da252fad
3
+ size 40812375
flax_model_to_pytorch.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import RobertaForMaskedLM, FlaxRobertaForMaskedLM, AutoTokenizer
2
+ import torch
3
+ import numpy as np
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ def to_f32(t):
8
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
9
+
10
+ jax.config.update('jax_platform_name', 'cpu')
11
+ MODEL_PATH = "./"
12
+ model = FlaxRobertaForMaskedLM.from_pretrained(MODEL_PATH)
13
+ model.params = to_f32(model.params)
14
+ model.save_pretrained(MODEL_PATH)
15
+
16
+ pt_model = RobertaForMaskedLM.from_pretrained(
17
+ MODEL_PATH, from_flax=True).to('cpu')
18
+
19
+ input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
20
+ input_ids_pt = torch.tensor(input_ids)
21
+
22
+ logits_pt = pt_model(input_ids_pt).logits
23
+ print(logits_pt)
24
+ logits_fx = model(input_ids).logits
25
+ print(logits_fx)
26
+
27
+ pt_model.save_pretrained(MODEL_PATH)
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:741983451ebd3f767044f9f28f8ad4621e946e22b9dac19ea0612e304300c307
3
  size 1421807019
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63a56183485d2c3a93eaa11d851da6c60c40d60c08704b68c9d1975bb3efbe40
3
  size 1421807019
start_train.sh CHANGED
@@ -5,15 +5,14 @@ export USE_TORCH=0
5
  python3 run_mlm_flax.py \
6
  --output_dir="./" \
7
  --model_name_or_path="./" \
8
- --from_pytorch \
9
  --config_name="./" \
10
  --tokenizer_name="./" \
11
  --dataset_filepath="/researchdisk/training_dataset_full_deduplicated" \
12
- --max_seq_length="128" \
13
  --pad_to_max_length \
14
  --preprocessing_num_workers="64" \
15
- --per_device_train_batch_size="64" \
16
- --per_device_eval_batch_size="64" \
17
  --adam_beta1="0.9" \
18
  --adam_beta2="0.98" \
19
  --adam_epsilon="1e-6" \
 
5
  python3 run_mlm_flax.py \
6
  --output_dir="./" \
7
  --model_name_or_path="./" \
 
8
  --config_name="./" \
9
  --tokenizer_name="./" \
10
  --dataset_filepath="/researchdisk/training_dataset_full_deduplicated" \
11
+ --max_seq_length="512" \
12
  --pad_to_max_length \
13
  --preprocessing_num_workers="64" \
14
+ --per_device_train_batch_size="16" \
15
+ --per_device_eval_batch_size="16" \
16
  --adam_beta1="0.9" \
17
  --adam_beta2="0.98" \
18
  --adam_epsilon="1e-6" \