import transformers

import model  # local module providing the base AbcTransformer implementation


class AbcTransformerConfig(transformers.PretrainedConfig):
    """Configuration holding the AbcTransformer hyperparameters."""

    model_type = 'abc-transformer'

    def __init__(
            self,
            vocab_size=113,
            n_embd=384,       # embedding dimension
            block_size=128,   # maximum context length
            n_heads=6,        # attention heads per block
            n_layers=6,       # number of transformer blocks
            dropout=0.2,
            device=None,
            **kwargs
    ):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.block_size = block_size
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dropout = dropout
        self.device = device
        super().__init__(**kwargs)


class AbcTransformer(transformers.PreTrainedModel):
    """PreTrainedModel wrapper around the local model.AbcTransformer."""

    config_class = AbcTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        # Build the underlying implementation from the config hyperparameters.
        self.model = model.AbcTransformer(
            vocab_size=config.vocab_size,
            n_embd=config.n_embd,
            block_size=config.block_size,
            n_heads=config.n_heads,
            n_layers=config.n_layers,
            dropout=config.dropout,
            device=config.device,
        )
    
    def forward(self, tensor, labels=None):
        # labels is optional so the wrapper can run without targets at
        # inference; assumes the wrapped model accepts labels=None.
        return self.model(tensor, labels)

# Register the custom classes so the transformers Auto* factories can
# resolve the 'abc-transformer' model type.
transformers.AutoConfig.register('abc-transformer', AbcTransformerConfig)
AbcTransformer.register_for_auto_class("AutoModelForCausalLM")  # records auto_map for save/push with custom code
transformers.AutoModel.register(AbcTransformerConfig, AbcTransformer)
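

# Minimal usage sketch (an assumption, not part of the original file): with
# the registrations above, the Auto* factories can build the model by its
# type name. Assumes model.AbcTransformer tolerates device=None.
if __name__ == '__main__':
    config = transformers.AutoConfig.for_model('abc-transformer')
    abc_model = transformers.AutoModel.from_config(config)
    print(type(abc_model).__name__)  # AbcTransformer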