File size: 618 Bytes
473506b 4023cca 473506b 4023cca 473506b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
#!/usr/bin/env python3
from transformers import FlaxGPT2LMHeadModel
from flax.jax_utils import replicate
from jax import jit, pmap
import numpy as np
# Pretrained GPT-2 (large) with a language-modelling head, loaded via the
# Hugging Face hub name; its ``params`` are fed explicitly to run_forward.
model = FlaxGPT2LMHeadModel.from_pretrained(
    pretrained_model_name_or_path="gpt2-large",
)
# Shape of the dummy batch fed to the model: BATCH_SIZE rows of SEQ_LEN
# token ids (presumably the (batch, sequence) layout of GPT-2 input_ids).
BATCH_SIZE = 4
SEQ_LEN = 256
# Every token id is 1; built directly as an int32 array instead of going
# through a nested Python list.
dummy_inputs = np.ones((BATCH_SIZE, SEQ_LEN), dtype=np.int32)
def run_forward(inputs, params):
    """Run the module-level ``model`` on ``inputs`` and return its logits.

    The weights are taken from ``params`` rather than from ``model`` itself.
    This lets the same function be wrapped by ``jit`` (single device) and by
    ``pmap`` (per-device replicated parameters).

    Args:
        inputs: Token ids passed as the model's first positional argument
            (presumably ``input_ids`` of shape (batch, seq) -- confirm).
        params: Parameter pytree for ``model``, e.g. ``model.params``.

    Returns:
        The ``logits`` attribute of the model output.
    """
    return model(inputs, params=params).logits
# Simple (single-device) forward pass: compile run_forward with XLA, then
# evaluate it once on the dummy batch using the model's own parameters.
jitted_forward = jit(run_forward)
logits = jitted_forward(dummy_inputs, model.params)
# Parallel forward pass across all local devices.
# pmap maps over the leading axis of every argument, so both the parameters
# and the inputs are first replicated (a new leading device axis is added).
# Each device therefore runs the same dummy batch.
p_forward = pmap(run_forward)
p_params = replicate(model.params)
p_inputs = replicate(dummy_inputs)
logits = p_forward(p_inputs, p_params)
|