from transformers import PretrainedConfig | |
from typing import List | |
class MLPConfig(PretrainedConfig):
    """Configuration holding the layer sizes of an "mlp" model.

    The three size arguments are stored as attributes of the same name. Every
    other keyword argument is forwarded unchanged to
    ``transformers.PretrainedConfig.__init__``. How the sizes are turned into
    layers is decided by whatever model class consumes this config.

    Args:
        input_size: Size of the model input. Defaults to 784.
        output_size: Size of the model output. Defaults to 10.
        hidden_size: Size of the hidden layer(s). Defaults to 256.
        **kwargs: Extra options passed through to ``PretrainedConfig``.

    Raises:
        ValueError: If any of ``input_size``, ``output_size`` or
            ``hidden_size`` is not strictly positive.
    """

    # Identifier that transformers' Auto* machinery uses to recognise this config type.
    model_type = "mlp"

    def __init__(
        self,
        input_size: int = 784,
        output_size: int = 10,
        hidden_size: int = 256,
        **kwargs,
    ):
        # Fail fast on impossible layer sizes instead of erroring later at model build time.
        for arg_name, value in (
            ("input_size", input_size),
            ("output_size", output_size),
            ("hidden_size", hidden_size),
        ):
            if value <= 0:
                raise ValueError(
                    f"`{arg_name}` must be a positive integer, got {value!r}."
                )

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        # Custom attributes are set before the base initializer, matching the transformers guide.
        super().__init__(**kwargs)
# Register MLPConfig with AutoConfig (the documented default auto class, spelled out here)
# so that saving/pushing the config records how to load this custom class remotely.
MLPConfig.register_for_auto_class("AutoConfig")