|
import jax |
|
import jax.numpy as jnp |
|
from transformers import FlaxGPTNeoForCausalLM, GPTNeoConfig |
|
# Load the pretrained GPT-Neo 1.3B checkpoint (original vocab_size = 50257).
model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")

# Allocate a new token-embedding matrix padded to 50264 rows; the extra rows
# (50257..50263) stay zero-initialized.
emb = jnp.zeros((50264, model.config.hidden_size))

# Copy the pretrained embeddings into the first 50257 rows.
# NOTE: jax.ops.index_update/jax.ops.index were removed in JAX 0.3.0;
# the supported equivalent is the functional .at[...].set(...) API.
emb = emb.at[:50257, :].set(model.params["transformer"]["wte"]["embedding"])

# Patch the parameter tree in place with the padded embedding table.
params = model.params
params["transformer"]["wte"]["embedding"] = emb

# Re-instantiate the model under a config that declares the padded vocab size,
# install the patched parameters, and write the result to disk.
config = GPTNeoConfig.from_pretrained("EleutherAI/gpt-neo-1.3B", vocab_size=50264)
model = FlaxGPTNeoForCausalLM(config)

model.params = params
model.save_pretrained("gpt-neo-1.3B")
|
|