|
from dataclasses import dataclass |
|
from transformers import PretrainedConfig |
|
|
|
class GPTConfig(PretrainedConfig):
    """Configuration class for the custom GPT model.

    Holds the model's architecture hyper-parameters and plugs them into the
    Hugging Face ``PretrainedConfig`` machinery, so the config can be saved
    and reloaded with ``save_pretrained`` / ``from_pretrained``.

    The per-field meanings below follow nanoGPT-style naming and are not
    enforced here -- NOTE(review): confirm against the model implementation.

    Args:
        block_size: Presumably the maximum sequence length (context window).
        vocab_size: Presumably the number of token ids in the vocabulary.
        n_layer: Presumably the number of transformer blocks.
        n_head: Presumably the number of attention heads per block.
        n_embd: Presumably the embedding / hidden dimension.
        dropout: Presumably the dropout probability used by the model.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``
            (e.g. ``architectures``, ``transformers_version``,
            ``return_dict``).
    """

    model_type = "custom_gpt"

    # Class-level defaults, kept so ``GPTConfig.<field>`` still resolves as
    # it did when this class was a dataclass. Instances shadow them in
    # ``__init__``; keep both sets of defaults in sync.
    block_size: int = 768
    vocab_size: int = 50257
    n_layer: int = 8
    n_head: int = 8
    n_embd: int = 768
    dropout: float = 0.1

    # An explicit __init__ replaces the former @dataclass decorator. The
    # dataclass-generated __init__ never called PretrainedConfig.__init__,
    # which left the base-class attributes unset. It also rejected the extra
    # keys (``model_type``, ``transformers_version``, ...) that ``from_dict``
    # passes back in when a saved config is reloaded.
    def __init__(
        self,
        block_size: int = 768,
        vocab_size: int = 50257,
        n_layer: int = 8,
        n_head: int = 8,
        n_embd: int = 768,
        dropout: float = 0.1,
        **kwargs,
    ) -> None:
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        # Initializes the shared PretrainedConfig attributes and stores any
        # remaining kwargs (e.g. metadata read back from config.json).
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load a ``GPTConfig`` via ``PretrainedConfig.from_pretrained``.

        Currently a pass-through: every positional and keyword argument is
        forwarded unchanged to the base implementation, and its result is
        returned as-is. Kept as the hook for any future custom loading logic.
        """
        return super().from_pretrained(*args, **kwargs)