"""Grow the GPT-Neo 1.3B token-embedding matrix to a larger vocabulary.

Loads the pre-trained Flax checkpoint, copies its token embeddings into the
top rows of a larger zero-initialised matrix, installs the result in a model
configured with the new vocabulary size, and saves that model locally.
"""

import jax
import jax.numpy as jnp
from transformers import FlaxGPTNeoForCausalLM, GPTNeoConfig

PRETRAINED_NAME = "EleutherAI/gpt-neo-1.3B"
NEW_VOCAB_SIZE = 50264
OUTPUT_DIR = "gpt-neo-1.3B"

model = FlaxGPTNeoForCausalLM.from_pretrained(PRETRAINED_NAME)

params = model.params
pretrained_embedding = params["transformer"]["wte"]["embedding"]
# Read the pre-trained vocabulary size from the checkpoint itself instead of
# hard-coding it, so the slice below always matches the loaded weights.
old_vocab_size = pretrained_embedding.shape[0]
if NEW_VOCAB_SIZE < old_vocab_size:
    raise ValueError(
        f"NEW_VOCAB_SIZE ({NEW_VOCAB_SIZE}) is smaller than the pre-trained "
        f"vocabulary size ({old_vocab_size})"
    )

# Rows for the newly added token ids stay zero; keep the checkpoint's dtype
# so the copy below does not silently cast the weights.
emb = jnp.zeros(
    (NEW_VOCAB_SIZE, model.config.hidden_size),
    dtype=pretrained_embedding.dtype,
)
# Copy the pre-trained rows into the top of the enlarged matrix.
# `jax.ops.index_update` was removed from JAX; `.at[...].set(...)` is the
# supported functional-update API and returns a new array.
emb = emb.at[:old_vocab_size, :].set(pretrained_embedding)

# NOTE(review): assumes `model.params` is a mutable nested dict; if the
# installed transformers/flax version returns a FrozenDict here, this
# assignment will fail and the tree must be unfrozen first -- confirm.
params["transformer"]["wte"]["embedding"] = emb

# Initialize a random model with the right vocab_size.
config = GPTNeoConfig.from_pretrained(PRETRAINED_NAME, vocab_size=NEW_VOCAB_SIZE)
model = FlaxGPTNeoForCausalLM(config)

# Assign the pre-trained weights (with the resized embedding) and save.
model.params = params
model.save_pretrained(OUTPUT_DIR)