"""Hugging Face wrappers around the ABC transformer model."""

import transformers

import model  # local module providing the underlying AbcTransformer implementation


class AbcTransformerConfig(transformers.PretrainedConfig):
    """Hyperparameters for the ABC transformer, exposed as a transformers config."""

    model_type = 'abc-transformer'

    def __init__(
        self,
        vocab_size=113,
        n_embd=384,
        block_size=128,
        n_heads=6,
        n_layers=6,
        dropout=0.2,
        device=None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.block_size = block_size
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.dropout = dropout
        self.device = device
        super().__init__(**kwargs)


class AbcTransformer(transformers.PreTrainedModel):
    """transformers-compatible wrapper that delegates to model.AbcTransformer."""

    config_class = AbcTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = model.AbcTransformer(
            vocab_size=config.vocab_size,
            n_embd=config.n_embd,
            block_size=config.block_size,
            n_heads=config.n_heads,
            n_layers=config.n_layers,
            dropout=config.dropout,
            device=config.device,
        )

    def forward(self, tensor, labels):
        # Delegate directly to the wrapped model.
        return self.model(tensor, labels)


# Register the custom classes so the Auto classes resolve the 'abc-transformer' model type.
transformers.AutoConfig.register('abc-transformer', AbcTransformerConfig)
AbcTransformer.register_for_auto_class("AutoModelForCausalLM")
transformers.AutoModel.register(AbcTransformerConfig, AbcTransformer)
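
# A minimal usage sketch, assuming the local `model` module imports and builds
# without extra setup; the checkpoint directory name below is hypothetical.
if __name__ == '__main__':
    config = AbcTransformerConfig()
    # The AutoModel mapping registered above resolves this config to AbcTransformer.
    abc_model = transformers.AutoModel.from_config(config)
    abc_model.save_pretrained('abc-transformer-checkpoint')
    reloaded = AbcTransformer.from_pretrained('abc-transformer-checkpoint')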
|