"""TensorFlow implementation of MyModel built on the transformers TF base classes."""

import tensorflow as tf
from transformers.modeling_tf_utils import unpack_inputs
from transformers.modeling_tf_utils import TFPreTrainedModel

from .configuration_my_model import MyModelConfig


class TFMyModelPretrainedModel(TFPreTrainedModel):
    """Base class binding ``MyModelConfig`` to the transformers TF pretrained-model machinery."""

    config_class = MyModelConfig


class TFMyModel(TFMyModelPretrainedModel):
    """Minimal TF model that applies a single dense projection to ``hidden``.

    Args:
        config: Model configuration; ``n_layers`` and ``hidden_dim`` are read from it.
    """

    def __init__(self, config: MyModelConfig):
        super().__init__(config)
        self.config = config
        self.n_layers = config.n_layers
        self.hidden_dim = config.hidden_dim
        # NOTE(review): the projection's output width is config.n_layers rather than
        # config.hidden_dim — confirm this is intended.
        self.linear = tf.keras.layers.Dense(units=config.n_layers)

    @property
    def dummy_inputs(self):
        """Dummy inputs used to build the network.

        Returns:
            A dict with a single ``"hidden"`` entry: a zero tensor of shape
            ``(1, config.hidden_dim)``.
        """
        return {"hidden": tf.zeros(shape=(1, self.config.hidden_dim))}

    @unpack_inputs
    def call(
        self,
        hidden,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """Run the forward pass.

        Args:
            hidden: Input tensor; presumably of shape ``(batch, hidden_dim)`` to
                match ``dummy_inputs`` — TODO confirm.
            output_attentions: Accepted for API compatibility; currently ignored.
            output_hidden_states: Accepted for API compatibility; currently ignored.
            return_dict: Accepted for API compatibility; currently ignored.

        Returns:
            The output of the dense projection applied to ``hidden``.
        """
        # Previously a stray ``breakpoint()`` halted every forward pass (including
        # the build pass on dummy inputs) and the projection result was discarded,
        # so the model returned None. Return the projection instead.
        return self.linear(hidden)