# norwegian-gptneo-red / setup_devices.py
# Author: pere — revision bf8b191 ("test")
import jax
import jax.numpy as jnp
from transformers import FlaxGPTNeoForCausalLM, GPTNeoConfig
# Pre-trained checkpoint whose weights are reused.
BASE_MODEL = "EleutherAI/gpt-neo-1.3B"
# Vocabulary size of the resized model (the pre-trained one has fewer rows).
NEW_VOCAB_SIZE = 50264
# Directory the resized model is written to.
OUTPUT_DIR = "gpt-neo-1.3B"


def pad_embedding_matrix(embedding, new_vocab_size):
    """Return ``embedding`` grown to ``new_vocab_size`` rows, zero-padding the tail.

    Rows ``[0, old_vocab_size)`` are copied unchanged from ``embedding``; the
    additional rows are zeros. The result keeps the input's dtype.

    Args:
        embedding: 2-D array of shape ``(old_vocab_size, hidden_size)``.
        new_vocab_size: Number of rows in the returned matrix.

    Returns:
        A new array of shape ``(new_vocab_size, hidden_size)``.

    Raises:
        ValueError: If ``new_vocab_size`` is smaller than ``old_vocab_size``
            (shrinking would silently drop pre-trained rows).
    """
    old_vocab_size, hidden_size = embedding.shape
    if new_vocab_size < old_vocab_size:
        raise ValueError(
            f"new_vocab_size ({new_vocab_size}) is smaller than the existing "
            f"vocabulary ({old_vocab_size}); refusing to truncate embeddings"
        )
    padded = jnp.zeros((new_vocab_size, hidden_size), dtype=embedding.dtype)
    # `jax.ops.index_update` was removed from JAX; `.at[...].set()` is the
    # functional replacement and returns a new array.
    return padded.at[:old_vocab_size].set(embedding)


def main():
    """Resize the base model's token embedding to NEW_VOCAB_SIZE and save it."""
    pretrained = FlaxGPTNeoForCausalLM.from_pretrained(BASE_MODEL)
    params = pretrained.params
    wte = params["transformer"]["wte"]
    # Keep the pre-trained rows; the newly added token rows start at zero.
    wte["embedding"] = pad_embedding_matrix(wte["embedding"], NEW_VOCAB_SIZE)

    # Build a model whose config declares the new vocabulary size, then swap in
    # the (resized) pre-trained weights before saving.
    config = GPTNeoConfig.from_pretrained(BASE_MODEL, vocab_size=NEW_VOCAB_SIZE)
    model = FlaxGPTNeoForCausalLM(config)
    model.params = params
    model.save_pretrained(OUTPUT_DIR)


if __name__ == "__main__":
    main()