|
|
|
from transformers import FlaxGPT2LMHeadModel |
|
from flax.jax_utils import replicate |
|
from jax import jit, pmap |
|
import numpy as np |
|
|
|
# Load the pretrained GPT-2 large checkpoint; its weights are exposed as
# `model.params` and are passed explicitly to the forward functions below.
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2-large")

# Dummy batch of token ids: 4 sequences of 256 tokens, every id set to 1.
dummy_inputs = np.ones((4, 256), dtype=np.int32)
|
|
|
|
|
def run_forward(inputs, params):
    """Run a GPT-2 forward pass and return the LM-head logits.

    The architecture comes from the module-level ``model``, but the weights
    are taken explicitly through ``params``. That way, when this function is
    wrapped with ``jax.jit`` / ``jax.pmap``, the parameters are passed in as
    traced arguments instead of being baked in as constants.

    Args:
        inputs: Token-id array, forwarded unchanged to ``model``.
        params: Parameter pytree in the same format as ``model.params``.

    Returns:
        The ``logits`` attribute of the model output.
    """
    return model(inputs, params=params).logits
|
|
|
|
|
# Single-device compiled forward pass. The first call traces and compiles
# `run_forward`; later calls with the same input shapes/dtypes reuse the
# compiled computation.
jitted_forward = jit(run_forward)

logits = jitted_forward(dummy_inputs, model.params)
|
|
|
|
|
|
|
# Multi-device variant: `pmap` runs `run_forward` once per local device,
# mapping over the leading axis of every argument.
p_forward = pmap(run_forward)

# `replicate` places one copy of a pytree on every local device, stacked
# along a new leading device axis. NOTE(review): the inputs are replicated
# as well, not sharded, so every device processes the same full batch;
# shard the batch instead if genuine data parallelism is intended.
p_params = replicate(model.params)
p_inputs = replicate(dummy_inputs)

# Output gains a leading device axis; since every device receives the same
# params and inputs, the per-device logits are presumably identical.
logits = p_forward(p_inputs, p_params)
|
|