import torch
from transformers import BertModel

from .configuration import NewModelConfig


class NewModel(BertModel):
    """BERT model with an additional linear layer stored as ``last_layer``.

    Everything except ``last_layer`` is built by ``BertModel.__init__``.
    ``last_layer`` maps vectors of size ``config.hidden_size`` to vectors of
    size ``config.new_hidden_size``.

    NOTE(review): this class does not override ``forward``. As written,
    ``last_layer`` is therefore never applied by the inherited
    ``BertModel.forward``. Confirm whether callers apply it themselves, or
    whether a ``forward`` override is missing.
    """

    # Configuration class paired with this model; it is read by the
    # transformers PreTrainedModel machinery.
    config_class = NewModelConfig

    def __init__(self, config: NewModelConfig) -> None:
        """Build the base BERT model, then add the extra projection layer.

        Args:
            config: Model configuration. It must provide ``hidden_size``
                (input width of ``last_layer``) and ``new_hidden_size``
                (output width of ``last_layer``).
        """
        super().__init__(config)
        # NOTE(review): this layer is created after super().__init__()
        # returns, i.e. after BertModel has finished its own setup. Confirm
        # whether it should receive the model's weight initialization (e.g.
        # via post_init()) instead of torch.nn.Linear's default init.
        self.last_layer = torch.nn.Linear(config.hidden_size, config.new_hidden_size)