File size: 702 Bytes
6fcd376 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import LlamaModel, LlamaConfig
import torch.nn as nn
# Number of token rows in the embedding table and output features of the LM
# head that CodeLlamaForCausalLM builds; used there instead of config.vocab_size.
CODELLAMA_VOCAB_SIZE = 32_016
class CodeLlamaForCausalLM(LlamaForCausalLM):
    """``LlamaForCausalLM`` variant with fixed-size embedding and output layers.

    The token embedding table and the LM head are always created with
    ``CODELLAMA_VOCAB_SIZE`` entries, regardless of ``config.vocab_size``.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        """Build the model, then swap in the fixed-size vocabulary layers.

        Args:
            config: Model configuration. ``hidden_size``, ``vocab_size`` and
                ``pad_token_id`` are read here; the whole object is forwarded
                to the parent constructor.
        """
        # The parent constructor already creates ``self.model`` as a
        # ``LlamaModel(config)``. Building a second one here would only
        # allocate the whole decoder stack twice, so the existing one is reused.
        super().__init__(config)

        # NOTE(review): this mirrors the config value, which may differ from
        # CODELLAMA_VOCAB_SIZE used for the layers below; kept unchanged for
        # backward compatibility with callers that read ``vocab_size``.
        self.vocab_size = config.vocab_size

        # Replace the config-sized layers created by the parent with ones
        # sized by CODELLAMA_VOCAB_SIZE.
        self.lm_head = nn.Linear(
            config.hidden_size, CODELLAMA_VOCAB_SIZE, bias=False
        )
        self.model.embed_tokens = nn.Embedding(
            CODELLAMA_VOCAB_SIZE,
            config.hidden_size,
            padding_idx=config.pad_token_id,
        )

        # Initialize weights and apply final processing
        self.post_init()
|