Patrick von Platen
commited on
Commit
•
4023cca
1
Parent(s):
473506b
run forward
Browse files
run_forward_gpt2_large.py
CHANGED
@@ -4,12 +4,12 @@ from flax.jax_utils import replicate
|
|
4 |
from jax import jit, pmap
|
5 |
import numpy as np
|
6 |
|
7 |
-
model = FlaxGPT2LMHeadModel.from_pretrained("
|
8 |
dummy_inputs = np.array(4 * [256 * [1]], dtype=np.int32)
|
9 |
|
10 |
|
11 |
def run_forward(inputs, params):
|
12 |
-
return model(inputs, params).logits
|
13 |
|
14 |
|
15 |
jitted_forward = jit(run_forward)
|
|
|
4 |
from jax import jit, pmap
|
5 |
import numpy as np
|
6 |
|
7 |
+
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2-large")
|
8 |
dummy_inputs = np.array(4 * [256 * [1]], dtype=np.int32)
|
9 |
|
10 |
|
11 |
def run_forward(inputs, params):
|
12 |
+
return model(inputs, params=params).logits
|
13 |
|
14 |
|
15 |
jitted_forward = jit(run_forward)
|