# magma/language_model.py
# Provenance (from HuggingFace file page): author stellaathena,
# commit bb5cd12 ("This should work")
import torch
from transformers import GPTNeoForCausalLM, AutoConfig, GPT2LMHeadModel
from .utils import print_main
from pathlib import Path
from transformers.modeling_utils import no_init_weights
# Names of supported language-model backbones. Presumably used by callers
# elsewhere in the package to validate/select the configured LM — verify
# against caller.
LANGUAGE_MODELS = [
    "gptj",
]
def gptj_config():
    """Build a GPT-J style config by patching a pretrained GPT-Neo 2.7B config.

    Downloads the "EleutherAI/gpt-neo-2.7B" config from the HF hub, then
    overrides its architecture fields (28 all-global attention layers,
    16 heads, 4096 hidden size, rotary embeddings with dim 64, 50400 vocab)
    and enables gradient checkpointing.
    """
    cfg = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
    num_layers = 28
    num_heads = 16
    cfg.attention_layers = ["global"] * num_layers
    cfg.attention_types = [["global"], num_layers]
    cfg.num_layers = num_layers
    cfg.num_heads = num_heads
    cfg.hidden_size = num_heads * 256
    cfg.vocab_size = 50400
    cfg.rotary = True
    cfg.rotary_dim = 64
    # NOTE(review): non-standard flag — presumably marks JAX-ported weights
    # for downstream loading code; confirm against consumers of this config.
    cfg.jax = True
    cfg.gradient_checkpointing = True
    return cfg
def get_gptj(
    gradient_checkpointing: bool = True,
    from_pretrained: bool = False,
) -> torch.nn.Module:
    """
    Loads GPTJ language model from HF.

    Args:
        gradient_checkpointing: enable activation checkpointing on the
            config; also disables the KV cache (`use_cache`), which is
            incompatible with checkpointing.
        from_pretrained: load pretrained weights — not yet supported.

    Returns:
        A GPTNeoForCausalLM built from the GPT-J config, constructed with
        HF weight initialization skipped (weights are presumably loaded or
        initialized by the caller — verify against call sites).

    Raises:
        NotImplementedError: if ``from_pretrained`` is True.
    """
    print_main("Loading GPTJ language model...")
    config = gptj_config()
    config.gradient_checkpointing = gradient_checkpointing
    if gradient_checkpointing:
        # KV caching is incompatible with gradient checkpointing.
        config.use_cache = False
    config.model_device = "cpu"
    if from_pretrained:
        # FIX: was `raise NotImplemented(...)` — NotImplemented is a
        # non-callable constant, so that line raised TypeError instead of
        # the intended exception. NotImplementedError is the correct class.
        raise NotImplementedError("GPTJ pretrained not implemented")
    # Skip HF's default weight init to avoid wasted work on a model whose
    # weights will be replaced.
    with no_init_weights():
        model = GPTNeoForCausalLM(config=config)
    return model