Spaces:
Runtime error
Runtime error
File size: 1,266 Bytes
bb5cd12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import torch
from transformers import GPTNeoForCausalLM, AutoConfig, GPT2LMHeadModel
from .utils import print_main
from pathlib import Path
from transformers.modeling_utils import no_init_weights
# Identifiers of the language models this module can build (presumably used
# by callers to validate a model name -- confirm against call sites).
LANGUAGE_MODELS = ["gptj"]
def gptj_config():
    """
    Build a GPT-J shaped configuration on top of the GPT-Neo 2.7B config.

    Fetches the "EleutherAI/gpt-neo-2.7B" configuration through
    ``AutoConfig.from_pretrained`` and overrides its fields in place:
    28 global-attention layers, 16 heads, hidden size 256 * 16,
    vocabulary size 50400, plus the ``rotary`` / ``rotary_dim`` / ``jax``
    flags and gradient checkpointing enabled.

    NOTE(review): ``rotary``, ``rotary_dim`` and ``jax`` are presumably read
    by a GPT-J-aware GPT-Neo implementation -- confirm against the installed
    transformers build.

    Returns:
        The modified configuration object.
    """
    depth = 28
    base = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")

    # Same attention kind for every layer.
    base.attention_layers = ["global" for _ in range(depth)]
    base.attention_types = [["global"], depth]
    base.num_layers = depth

    # Hidden size is derived from the head count that was just set.
    base.num_heads = 16
    base.hidden_size = base.num_heads * 256

    # Remaining overrides are plain value assignments.
    for field, value in (
        ("vocab_size", 50400),
        ("rotary", True),
        ("rotary_dim", 64),
        ("jax", True),
        ("gradient_checkpointing", True),
    ):
        setattr(base, field, value)

    return base
def get_gptj(
    gradient_checkpointing: bool = True,
    from_pretrained: bool = False,
) -> torch.nn.Module:
    """
    Build the GPT-J language model as a ``GPTNeoForCausalLM`` instance.

    The model is constructed from ``gptj_config()`` inside a
    ``no_init_weights()`` context; no pretrained weights are loaded here.

    Args:
        gradient_checkpointing: Value written to
            ``config.gradient_checkpointing``. When true, ``config.use_cache``
            is also set to ``False``.
        from_pretrained: Loading pretrained weights is not supported; passing
            a truthy value raises ``NotImplementedError``.

    Returns:
        The constructed ``GPTNeoForCausalLM`` model.

    Raises:
        NotImplementedError: If ``from_pretrained`` is truthy.
    """
    # Fail fast, before the config is fetched or anything is built.
    # `NotImplemented` is a non-callable constant (calling it raises
    # TypeError); `NotImplementedError` is the actual exception class.
    if from_pretrained:
        raise NotImplementedError("GPTJ pretrained not implemented")

    print_main("Loading GPTJ language model...")

    config = gptj_config()
    config.gradient_checkpointing = gradient_checkpointing
    if gradient_checkpointing:
        # Caching is switched off whenever checkpointing is requested.
        config.use_cache = False
    config.model_device = "cpu"

    # NOTE(review): presumably skips random weight initialisation because
    # real weights are loaded by the caller afterwards -- confirm.
    with no_init_weights():
        model = GPTNeoForCausalLM(config=config)
    return model
|