# Source: pt_custom_model_example / configuration_my_model.py
# Commit 981428d ("add config") by ydshieh.
from transformers import PretrainedConfig
class MyModelConfig(PretrainedConfig):
    """Configuration object for the ``my_model`` custom model.

    Args:
        n_layers: Value stored on the instance as ``self.n_layers``
            (default 2). Presumably the number of layers the model class
            builds -- confirm against the model implementation.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    # Class-level type identifier; transformers uses ``model_type`` to
    # recognise which configuration class a saved config belongs to.
    model_type = "my_model"

    def __init__(self, n_layers: int = 2, **kwargs) -> None:
        # Record the model-specific field before handing every remaining
        # keyword argument to the base class initializer.
        self.n_layers = n_layers
        super().__init__(**kwargs)