yhavinga committed on
Commit
04b230f
1 Parent(s): eb8efd2

Add pytorch model

Browse files
Files changed (3) hide show
  1. config.json +2 -1
  2. flax_to_pytorch.py +26 -0
  3. pytorch_model.bin +3 -0
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_name_or_path": "/home/patrick/hugging_face/t5/t5-v1_1-base",
3
  "architectures": [
4
  "T5ForConditionalGeneration"
5
  ],
@@ -21,6 +21,7 @@
21
  "pad_token_id": 0,
22
  "relative_attention_num_buckets": 32,
23
  "tie_word_embeddings": false,
 
24
  "transformers_version": "4.13.0",
25
  "use_cache": true,
26
  "vocab_size": 32103
 
1
  {
2
+ "_name_or_path": ".",
3
  "architectures": [
4
  "T5ForConditionalGeneration"
5
  ],
 
21
  "pad_token_id": 0,
22
  "relative_attention_num_buckets": 32,
23
  "tie_word_embeddings": false,
24
+ "torch_dtype": "float32",
25
  "transformers_version": "4.13.0",
26
  "use_cache": true,
27
  "vocab_size": 32103
flax_to_pytorch.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Convert the Flax T5 checkpoint in the current directory to PyTorch.

The script loads the Flax weights from ".", writes them back out as a
PyTorch checkpoint with ``save_pretrained("./")``, and then feeds one sample
sentence through both the Flax model and the converted PyTorch model. The
logits of the first decoding step are printed for each framework, followed
by their maximum absolute difference, so the conversion can be checked.
"""
import torch
import numpy as np
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained(".")
model_fx = FlaxT5ForConditionalGeneration.from_pretrained(".")
# from_flax=True makes transformers read the Flax weights and map them onto
# the PyTorch module; save_pretrained then writes the PyTorch checkpoint.
model_pt = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
model_pt.save_pretrained("./")

text = "Hoe gaat het?"

# Flax inputs: NumPy-encoded text plus a single decoder start token per row.
e_input_ids_fx = tokenizer(text, return_tensors="np", padding=True, max_length=128, truncation=True)
d_input_ids_fx = jnp.ones((e_input_ids_fx.input_ids.shape[0], 1), dtype="i4") * model_fx.config.decoder_start_token_id

# PyTorch inputs: same text as torch tensors. The decoder ids must be a
# torch integer tensor (not a NumPy array) to be consumed by the torch model.
e_input_ids_pt = tokenizer(text, return_tensors="pt", padding=True, max_length=128, truncation=True)
d_input_ids_pt = torch.full(
    (e_input_ids_pt.input_ids.shape[0], 1),
    model_pt.config.decoder_start_token_id,
    dtype=torch.long,
)

print(e_input_ids_fx)
print(d_input_ids_fx)
print()

# Run the *PyTorch* model here. The previous version called model_fx for this
# step as well, so the converted weights were never actually exercised.
with torch.no_grad():
    outputs_pt = model_pt(**e_input_ids_pt, decoder_input_ids=d_input_ids_pt)
logits_pt = outputs_pt.logits
print(logits_pt)

encoder_fx = model_fx.encode(**e_input_ids_fx)
decoder_fx = model_fx.decode(d_input_ids_fx, encoder_fx)
logits_fx = decoder_fx.logits
print(logits_fx)

# Summarize the agreement between both frameworks in a single number.
max_abs_diff = np.abs(np.asarray(logits_fx) - logits_pt.numpy()).max()
print("max abs diff:", max_abs_diff)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa8b87f8bb924ddaf9823ed6c9ed8f57adbee415b398049da58ddbe36997cf9a
3
+ size 990280781