from transformers import PretrainedConfig


class BilmaConfig(PretrainedConfig):
    model_type = "bilma"

    def __init__(
        self,
        weights: str = "spanish",
        num_attention_heads: int = 4,
        num_hidden_layers: int = 2,
        seq_max_length: int = 280,
        hidden_size: int = 512,
        vocab_size: int = 28949,
        hidden_dropout_prob: float = 0.1,
        **kwargs,
    ):
        if weights not in ["spanish", ""]:
            raise ValueError(f"`weights` must be 'spanish' or '', got {weights}.")
        if weights == "spanish":
            # Pretrained Spanish weights: pin every hyperparameter to the values
            # of the released checkpoint, ignoring any caller-supplied overrides.
            self.weights = weights
            self.num_attention_heads = 4
            self.num_hidden_layers = 2
            self.seq_max_length = 280
            self.hidden_size = 512
            self.vocab_size = 28949
            self.hidden_dropout_prob = 0.1
            super().__init__(**kwargs)
            return

        # No pretrained weights requested: use the caller-supplied hyperparameters.
        self.weights = weights
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.seq_max_length = seq_max_length
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.hidden_dropout_prob = hidden_dropout_prob
        super().__init__(**kwargs)
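

if __name__ == "__main__":
    # Usage sketch (not part of the original module): illustrates the two
    # configuration paths above. Assumes only this file and `transformers`.

    # 1) Default: pretrained Spanish hyperparameters are enforced, so the
    #    override below is silently ignored.
    cfg_es = BilmaConfig(weights="spanish", num_hidden_layers=6)
    assert cfg_es.num_hidden_layers == 2

    # 2) Empty `weights`: a fresh, fully custom configuration.
    cfg_custom = BilmaConfig(
        weights="",
        num_attention_heads=8,
        num_hidden_layers=4,
        hidden_size=768,
    )
    assert cfg_custom.hidden_size == 768

    # 3) Anything other than "spanish" or "" raises immediately.
    try:
        BilmaConfig(weights="english")
    except ValueError as err:
        print(err)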