File size: 618 Bytes
473506b
 
 
 
 
 
4023cca
473506b
 
 
 
4023cca
473506b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#!/usr/bin/env python3
from transformers import FlaxGPT2LMHeadModel
from flax.jax_utils import replicate
from jax import jit, pmap
import numpy as np

# Load the pretrained "gpt2-large" checkpoint as a Flax causal-LM model.
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2-large")

# Fixed dummy batch: 4 rows x 256 columns of int32 token id 1.
dummy_inputs = np.ones((4, 256), dtype=np.int32)


def run_forward(inputs, params):
    """Apply the module-level ``model`` to ``inputs`` with explicit ``params``.

    The weights are taken as an argument instead of being read from
    ``model.params``, so the function can be wrapped by ``jit``/``pmap`` and
    fed either plain or replicated parameters.

    Args:
        inputs: token-id array handed straight to ``model`` (in this script,
            int32 token ids).
        params: parameter pytree forwarded as ``model(..., params=params)``.

    Returns:
        The ``.logits`` attribute of the model output.
    """
    model_output = model(inputs, params=params)
    return model_output.logits


jitted_forward = jit(run_forward)

# Single-device forward pass through the jit-compiled function.
logits = jitted_forward(dummy_inputs, model.params)


# Multi-device forward pass. pmap maps run_forward over the leading axis of
# every argument, so both the parameters and the inputs need a leading
# device axis; `replicate` supplies it by placing one identical copy on each
# local device.
p_forward = pmap(run_forward)

p_params = replicate(model.params)
# NOTE: every device receives the same full batch here. To split the batch
# across devices instead, shard it along axis 0. That requires the batch
# size to be divisible by the device count.
p_inputs = replicate(dummy_inputs)

# The result carries the extra leading per-device axis.
logits = p_forward(p_inputs, p_params)