patrickvonplaten commited on
Commit
473506b
1 Parent(s): 141e415

add forward function

Browse files
Files changed (1) hide show
  1. run_forward_gpt2_large.py +27 -0
run_forward_gpt2_large.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ from transformers import FlaxGPT2LMHeadModel
3
+ from flax.jax_utils import replicate
4
+ from jax import jit, pmap
5
+ import numpy as np
6
+
7
+ model = FlaxGPT2LMHeadModel.from_pretrained("distilgpt2")
8
+ dummy_inputs = np.array(4 * [256 * [1]], dtype=np.int32)
9
+
10
+
11
+ def run_forward(inputs, params):
12
+ return model(inputs, params).logits
13
+
14
+
15
+ jitted_forward = jit(run_forward)
16
+
17
+ ## simple forward
18
+ logits = jitted_forward(dummy_inputs, model.params)
19
+
20
+
21
+ ## parallel forward
22
+ p_forward = pmap(run_forward)
23
+
24
+ p_params = replicate(model.params)
25
+ p_inptus = replicate(dummy_inputs)
26
+
27
+ logits = p_forward(p_inptus, p_params)