from pathlib import Path

import torch
from transformers import GPTNeoForCausalLM, AutoConfig, GPT2LMHeadModel
from transformers.modeling_utils import no_init_weights

from .utils import print_main

# Identifiers of the language models this module knows how to build.
LANGUAGE_MODELS = [
    "gptj",
]


def gptj_config():
    """
    Build the configuration object used to instantiate the GPTJ model.

    The config is loaded from the "EleutherAI/gpt-neo-2.7B" checkpoint and
    then overridden in place: 28 layers (all "global" attention), 16 heads,
    a hidden size of 256 per head (4096 total) and a vocabulary of 50400.

    Returns:
        The modified config object returned by ``AutoConfig.from_pretrained``.
    """
    config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")

    # Layer / attention layout.
    config.attention_layers = ["global"] * 28
    # NOTE(review): stock GPTNeoConfig documents a list of [types, count]
    # pairs here; this flat form presumably targets the transformers
    # version/fork this project pins -- confirm before changing.
    config.attention_types = [["global"], 28]
    config.num_layers = 28
    config.num_heads = 16
    config.hidden_size = 256 * config.num_heads
    config.vocab_size = 50400

    # NOTE(review): `rotary`, `rotary_dim` and `jax` are not set anywhere else
    # in this file; presumably they are read by the GPT-Neo implementation in
    # use -- verify against the installed transformers package.
    config.rotary = True
    config.rotary_dim = 64
    config.jax = True

    # Default only; get_gptj() overwrites this with its own argument.
    config.gradient_checkpointing = True
    return config


def get_gptj(
    gradient_checkpointing: bool = True,
    from_pretrained: bool = False,
) -> torch.nn.Module:
    """
    Construct the GPTJ language model from ``gptj_config()``.

    The model is created as a ``GPTNeoForCausalLM`` inside the
    ``no_init_weights()`` context manager; no pretrained weights are loaded
    by this function.

    Args:
        gradient_checkpointing: Value written to
            ``config.gradient_checkpointing``. When truthy,
            ``config.use_cache`` is also set to ``False``.
        from_pretrained: Loading pretrained weights is not supported; any
            truthy value raises.

    Returns:
        The constructed ``GPTNeoForCausalLM`` instance.

    Raises:
        NotImplementedError: If ``from_pretrained`` is truthy.
    """
    print_main("Loading GPTJ language model...")

    config = gptj_config()
    config.gradient_checkpointing = gradient_checkpointing
    if gradient_checkpointing:
        config.use_cache = False
    config.model_device = "cpu"

    if from_pretrained:
        # `NotImplemented` is a non-callable sentinel constant, not an
        # exception; `NotImplementedError` is the exception to raise here.
        raise NotImplementedError("GPTJ pretrained not implemented")

    with no_init_weights():
        model = GPTNeoForCausalLM(config=config)
    return model