Patrick von Platen commited on
Commit
4023cca
1 Parent(s): 473506b

run forward

Browse files
Files changed (1) hide show
  1. run_forward_gpt2_large.py +2 -2
run_forward_gpt2_large.py CHANGED
@@ -4,12 +4,12 @@ from flax.jax_utils import replicate
4
  from jax import jit, pmap
5
  import numpy as np
6
 
7
- model = FlaxGPT2LMHeadModel.from_pretrained("distilgpt2")
8
  dummy_inputs = np.array(4 * [256 * [1]], dtype=np.int32)
9
 
10
 
11
  def run_forward(inputs, params):
12
- return model(inputs, params).logits
13
 
14
 
15
  jitted_forward = jit(run_forward)
 
4
  from jax import jit, pmap
5
  import numpy as np
6
 
7
+ model = FlaxGPT2LMHeadModel.from_pretrained("gpt2-large")
8
  dummy_inputs = np.array(4 * [256 * [1]], dtype=np.int32)
9
 
10
 
11
  def run_forward(inputs, params):
12
+ return model(inputs, params=params).logits
13
 
14
 
15
  jitted_forward = jit(run_forward)