# gpt2-layout-generation / run_forward_gpt2_large.py
# Patrick von Platen
# run forward
# 4023cca
#!/usr/bin/env python3
from transformers import FlaxGPT2LMHeadModel
from flax.jax_utils import replicate
from jax import jit, pmap
import numpy as np
# Pretrained "gpt2-large" checkpoint loaded as a Flax causal-LM model.
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2-large")

# Dummy batch of token ids: 4 sequences of 256 tokens, every id set to 1.
dummy_inputs = np.ones((4, 256), dtype=np.int32)
def run_forward(inputs, params):
    """Run a forward pass of the module-level ``model`` and return its logits.

    ``params`` is taken as an explicit argument instead of being read from
    ``model.params``. That way ``jit`` and ``pmap`` see the parameters as
    traced inputs, and ``pmap`` can map over a replicated copy of them.

    Args:
        inputs: integer token-id array passed to the model as its first
            positional argument (``dummy_inputs`` in this script).
        params: parameter pytree for ``model`` (e.g. ``model.params`` or a
            ``replicate``-d copy of it under ``pmap``).

    Returns:
        The ``logits`` field of the model output.
    """
    return model(inputs, params=params).logits
# Compile the forward pass for a single device.
jitted_forward = jit(run_forward)

## simple forward
logits = jitted_forward(dummy_inputs, model.params)
# JAX dispatches work asynchronously; block so the computation (and any error
# it raises) completes here instead of being silently left pending.
logits.block_until_ready()

## parallel forward
# pmap maps over the leading axis of every argument by default, so both the
# params and the inputs get a leading device axis via `replicate`: each local
# device runs the same forward pass on its own copy of the batch.
p_forward = pmap(run_forward)
p_params = replicate(model.params)
p_inputs = replicate(dummy_inputs)
logits = p_forward(p_inputs, p_params)
logits.block_until_ready()