from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import LlamaModel, LlamaConfig
import torch.nn as nn

CODELLAMA_VOCAB_SIZE = 32016


class CodeLlamaForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM variant whose input embeddings and LM head are sized
    for the Code Llama vocabulary (32016 tokens) rather than the vocabulary
    declared in the base Llama config."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        # Track the actual output dimension of the resized head.
        self.vocab_size = CODELLAMA_VOCAB_SIZE
        # Re-create the output head and input embeddings with the Code Llama
        # vocabulary size so checkpoints with the extended vocab load cleanly.
        self.lm_head = nn.Linear(config.hidden_size, CODELLAMA_VOCAB_SIZE, bias=False)
        self.model.embed_tokens = nn.Embedding(
            CODELLAMA_VOCAB_SIZE, config.hidden_size, config.pad_token_id
        )

        # Initialize weights and apply final processing
        self.post_init()
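

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): build a small
    # LlamaConfig, instantiate the model, and run a forward pass to confirm
    # that the resized embedding and lm_head produce logits over the
    # Code Llama vocabulary. The config values below are illustrative only.
    import torch

    config = LlamaConfig(
        vocab_size=CODELLAMA_VOCAB_SIZE,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        pad_token_id=0,
    )
    model = CodeLlamaForCausalLM(config)

    input_ids = torch.randint(0, CODELLAMA_VOCAB_SIZE, (1, 8))
    logits = model(input_ids=input_ids).logits
    assert logits.shape[-1] == CODELLAMA_VOCAB_SIZE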