"""Define a minimal TensorFlow model on top of Hugging Face ``TFPreTrainedModel``.

When run as a script, instantiate the model from a default ``MyModelConfig``,
print it, and save it to the ``my_model`` directory with ``save_pretrained``.
"""

import tensorflow as tf
from transformers.modeling_tf_utils import TFPreTrainedModel

from configuration_my_model import MyModelConfig


class TFMyModelPretrainedModel(TFPreTrainedModel):
    """Base class binding ``MyModelConfig`` to the Hugging Face TF model machinery."""

    # Tells ``from_pretrained`` which configuration class to instantiate.
    config_class = MyModelConfig


class TFMyModel(TFMyModelPretrainedModel):
    """Model consisting of a single dense projection layer."""

    def __init__(self, config: MyModelConfig, *inputs, **kwargs):
        """Create the model's layers from *config*.

        Args:
            config: Model configuration; ``n_layers`` and ``hidden_dim`` are read.
            *inputs: Extra positional arguments, forwarded to
                ``TFPreTrainedModel.__init__``.
            **kwargs: Extra keyword arguments (e.g. those passed through by
                ``from_pretrained``), forwarded to ``TFPreTrainedModel.__init__``.
        """
        # The base initializer stores ``config`` as ``self.config``, so it is
        # not reassigned here.
        super().__init__(config, *inputs, **kwargs)
        self.n_layers = config.n_layers
        self.hidden_dim = config.hidden_dim
        # NOTE(review): the output width is ``n_layers``, not ``hidden_dim``.
        # This may be a mix-up; confirm before relying on the output shape.
        self.linear = tf.keras.layers.Dense(units=config.n_layers)

    def call(self, inputs, training=False):
        """Apply the dense projection to *inputs*.

        Args:
            inputs: Tensor whose last axis holds the input features
                (assumed float features; TODO confirm the expected shape).
            training: Unused. Accepted for Keras ``call`` API compatibility.

        Returns:
            ``self.linear(inputs)``. Its last axis has size ``config.n_layers``.
        """
        return self.linear(inputs)


def main():
    """Build a default model, print it, and save it to ``my_model``."""
    config = MyModelConfig()
    model = TFMyModel(config)
    # Keras creates layer weights lazily on the first call. Run one forward
    # pass so ``save_pretrained`` has the Dense weights to write. The input
    # width used here is an assumption (TODO confirm against real inputs).
    model(tf.zeros((1, config.hidden_dim)))
    print(model)
    model.save_pretrained("my_model")


if __name__ == "__main__":
    main()