from transformers import PretrainedConfig


class MultiHeadConfig(PretrainedConfig):
    """Configuration for the ``multihead`` model.

    Stores the backbone encoder identifier plus classification-head settings.
    The defaults describe a binary ``irrelevant`` / ``relevant`` task.

    Args:
        encoder_name: Identifier of the encoder checkpoint
            (default ``"microsoft/deberta-v3-small"``).
            NOTE(review): how this value is consumed (e.g. passed to a
            ``from_pretrained`` call) lives outside this file — confirm.
        **kwargs: Forwarded to :class:`PretrainedConfig`. Handled here:

            * ``classifier_dropout`` — defaults to ``0.1``.
            * ``id2label`` / ``label2id`` — default to
              ``{0: "irrelevant", 1: "relevant"}`` and its inverse, but only
              when neither map is supplied and ``num_labels`` is absent
              or ``2``.
            * ``tokenizer_class`` — defaults to ``"DebertaV2TokenizerFast"``.
    """

    model_type = "multihead"

    def __init__(self, encoder_name="microsoft/deberta-v3-small", **kwargs):
        # Defaults for attributes that PretrainedConfig.__init__ manages must be
        # injected into ``kwargs``; assigning them to ``self`` beforehand is
        # useless because the base constructor resets ``id2label``/``label2id``/
        # ``num_labels`` and ``tokenizer_class`` from ``kwargs`` (yielding
        # ``LABEL_0``/``LABEL_1`` and ``None`` instead of these defaults).
        #
        # Only apply the binary label mapping when the caller has not requested
        # a different label layout, to avoid a num_labels/id2label mismatch.
        if (
            "id2label" not in kwargs
            and "label2id" not in kwargs
            and kwargs.get("num_labels", 2) == 2
        ):
            kwargs["id2label"] = {0: "irrelevant", 1: "relevant"}
            kwargs["label2id"] = {"irrelevant": 0, "relevant": 1}
        kwargs.setdefault("tokenizer_class", "DebertaV2TokenizerFast")
        classifier_dropout = kwargs.pop("classifier_dropout", 0.1)

        super().__init__(**kwargs)

        # Set after the base constructor so nothing in it can overwrite them.
        self.encoder_name = encoder_name
        self.classifier_dropout = classifier_dropout