File size: 1,748 Bytes
522b344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import numpy as np
import jax.numpy as jnp

from transformers import AutoTokenizer
from transformers import FlaxT5ForConditionalGeneration
from transformers import T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("./")
model_fx = FlaxT5ForConditionalGeneration.from_pretrained("./")
model_pt = T5ForConditionalGeneration.from_pretrained("./", from_flax=True)
model_pt.save_pretrained("./")

text = """Het is nog niet duidelijk
welke hoogte het water nabij Venlo heeft bereikt. De hoogwaterpiek is vermoedelijk iets vlakker dan verwacht, maar blijft langer aanhouden, tot zondag 19.00 uur. Vooralsnog zijn er weinig meldingen over schade of overlast, meldt een woordvoerder van Veiligheidsregio Limburg-Noord zaterdag aan NU.nl. Via het Nationaal Rampenfonds is binnen één etmaal al 1 miljoen euro opgehaald voor gedupeerden.
"""
e_input_ids_fx = tokenizer(text, return_tensors="np", padding=True, max_length=128, truncation=True)
d_input_ids_fx = jnp.ones((e_input_ids_fx.input_ids.shape[0], 1), dtype="i4") * model_fx.config.decoder_start_token_id

print(e_input_ids_fx)
print(d_input_ids_fx)

e_input_ids_pt = tokenizer(text, return_tensors="pt", padding=True, max_length=128, truncation=True)
d_input_ids_pt = np.ones((e_input_ids_pt.input_ids.shape[0], 1), dtype="i4") * model_pt.config.decoder_start_token_id


print(e_input_ids_pt)
print(d_input_ids_pt)

print()

encoder_pt = model_fx.encode(**e_input_ids_pt)
decoder_pt = model_fx.decode(d_input_ids_pt, encoder_pt)
logits_pt = decoder_pt.logits
print(f"Pytorch output: {logits_pt}")

encoder_fx = model_fx.encode(**e_input_ids_fx)
decoder_fx = model_fx.decode(d_input_ids_fx, encoder_fx)
logits_fx = decoder_fx.logits
print(f"Flax output: {logits_fx}")