from transformers import PretrainedConfig, BertConfig
from typing import List


class VGCNConfig(BertConfig):
    """Configuration for a VGCN model layered on a BERT-style backbone.

    Extends ``BertConfig`` with the hyper-parameters of the vocabulary-graph
    (GCN) component. Every keyword not listed below is forwarded unchanged to
    ``BertConfig.__init__`` via ``**kwargs``.

    Args:
        bert_model: Identifier of the BERT checkpoint paired with this config.
            It is only stored here. Presumably it is a Hugging Face Hub name
            or a local path; confirm against the model code.
        gcn_adj_matrix: Reference to the GCN adjacency matrix. It is only
            stored here, and an empty string is the default. Presumably it is
            a file path; confirm against the model code.
        max_seq_len: Maximum sequence length, which must lie in ``[1, 512]``.
        npmi_threshold: NPMI cut-off, which must lie in ``[0.0, 1.0]``.
        tf_threshold: Term-frequency cut-off, which must lie in
            ``[0.0, 1.0]``.
        vocab_type: Which vocabulary to use: ``"all"``, ``"pmi"`` or ``"tf"``.
        gcn_embedding_dim: Size of the GCN embedding, which must be ``>= 1``.
        **kwargs: Remaining ``BertConfig`` / ``PretrainedConfig`` options.

    Raises:
        ValueError: If any validated argument above is outside its allowed
            set or range. NaN thresholds are rejected as well.
    """

    # Identifier transformers uses to tell this config type apart.
    model_type = "vgcn"

    def __init__(
        self,
        bert_model: str = 'readerbench/RoBERT-base',
        gcn_adj_matrix: str = '',
        max_seq_len: int = 256,
        npmi_threshold: float = 0.2,
        tf_threshold: float = 0.0,
        vocab_type: str = "all",
        gcn_embedding_dim: int = 32,
        **kwargs,
    ):
        if vocab_type not in ("all", "pmi", "tf"):
            raise ValueError(f"`vocab_type` must be 'all', 'pmi' or 'tf', got {vocab_type}.")
        # Chained comparisons are written as `not (lo <= x <= hi)` rather than
        # `x < lo or x > hi` so that NaN (for which every comparison is False)
        # is rejected instead of silently accepted.
        if not (1 <= max_seq_len <= 512):
            raise ValueError(f"`max_seq_len` must be between 1 and 512, got {max_seq_len}.")
        if not (0.0 <= npmi_threshold <= 1.0):
            raise ValueError(f"`npmi_threshold` must be between 0.0 and 1.0, got {npmi_threshold}.")
        if not (0.0 <= tf_threshold <= 1.0):
            raise ValueError(f"`tf_threshold` must be between 0.0 and 1.0, got {tf_threshold}.")
        # A zero or negative embedding size cannot describe a valid layer;
        # every other numeric hyper-parameter is validated, so this one is too.
        if not (gcn_embedding_dim >= 1):
            raise ValueError(f"`gcn_embedding_dim` must be >= 1, got {gcn_embedding_dim}.")

        self.gcn_adj_matrix = gcn_adj_matrix
        self.max_seq_len = max_seq_len
        self.npmi_threshold = npmi_threshold
        self.tf_threshold = tf_threshold
        self.vocab_type = vocab_type
        self.gcn_embedding_dim = gcn_embedding_dim
        self.bert_model = bert_model
        # The BERT/PretrainedConfig options are set last, from **kwargs.
        super().__init__(**kwargs)