#!/usr/bin/env python3
"""Run GPT-2 Large forward passes with Flax, once under ``jit`` and once under ``pmap``.

The input is a dummy batch in which every token id is 1, so the resulting
logits carry no meaning; the script only exercises the single-device (``jit``)
and multi-device (``pmap``) forward paths.
"""
import numpy as np
from flax.jax_utils import replicate
from jax import jit, pmap
from transformers import FlaxGPT2LMHeadModel

BATCH_SIZE = 4
SEQ_LEN = 256

model = FlaxGPT2LMHeadModel.from_pretrained("gpt2-large")

# A (BATCH_SIZE, SEQ_LEN) int32 batch in which every token id is 1.
dummy_inputs = np.ones((BATCH_SIZE, SEQ_LEN), dtype=np.int32)


def run_forward(inputs, params):
    """Return the LM-head logits of ``model`` for ``inputs`` under ``params``.

    ``params`` is an explicit argument rather than being read from
    ``model.params`` so that ``jit``/``pmap`` receive the weights as traced
    inputs. This is what lets ``pmap`` below map over a replicated copy of them.
    """
    return model(inputs, params=params).logits


jitted_forward = jit(run_forward)

# Single-device forward pass.
logits = jitted_forward(dummy_inputs, model.params)

# Multi-device forward pass. pmap maps run_forward over the leading axis of its
# arguments; replicate() copies params and inputs to every local device and
# adds that leading device axis.
# NOTE: because the inputs are replicated rather than sharded, every device
# processes the full batch. To split the batch across devices instead, reshape
# dummy_inputs to (n_devices, -1, SEQ_LEN).
# JAX dispatches asynchronously; call ``logits.block_until_ready()`` if timing.
p_forward = pmap(run_forward)
p_params = replicate(model.params)
p_inputs = replicate(dummy_inputs)
logits = p_forward(p_inputs, p_params)