from transformers import GPT2Config


class TranceptionConfig(GPT2Config):
    """Configuration for the Tranception architecture, built on ``GPT2Config``.

    Every Tranception-specific keyword is stored unchanged on the instance
    under an attribute of the same name. Any remaining keyword arguments are
    forwarded untouched to ``GPT2Config.__init__``.

    Args:
        attention_mode: Stored as ``self.attention_mode``; defaults to
            ``"tranception"``.
        position_embedding: Stored as ``self.position_embedding``; defaults
            to ``"grouped_alibi"``.
        tokenizer: Stored as ``self.tokenizer``; defaults to ``None``.
        retrieval_aggregation_mode: Stored as
            ``self.retrieval_aggregation_mode``; defaults to ``None``.
        retrieval_inference_weight: Stored as
            ``self.retrieval_inference_weight``; defaults to ``0.6``.
        MSA_filename: Stored as ``self.MSA_filename``; presumably a path to
            an alignment file -- confirm against the code that reads it.
        MSA_weight_file_name: Stored as ``self.MSA_weight_file_name``.
        MSA_start: Stored as ``self.MSA_start``.
        MSA_end: Stored as ``self.MSA_end``.
        full_protein_length: Stored as ``self.full_protein_length``.
        clustal_omega_location: Stored as ``self.clustal_omega_location``.
        scoring_window: Stored as ``self.scoring_window``.
        **kwargs: Passed through to the parent ``GPT2Config`` constructor.
    """

    def __init__(
        self,
        attention_mode="tranception",
        position_embedding="grouped_alibi",
        tokenizer=None,
        retrieval_aggregation_mode=None,
        retrieval_inference_weight=0.6,
        MSA_filename=None,
        MSA_weight_file_name=None,
        MSA_start=None,
        MSA_end=None,
        full_protein_length=None,
        clustal_omega_location=None,
        scoring_window=None,
        **kwargs,
    ):
        # Let the parent consume everything that is not Tranception-specific.
        super().__init__(**kwargs)

        # NOTE(review): assigned per instance after the parent initializer
        # rather than declared as a class attribute -- confirm this is what
        # the serialization / auto-class machinery expects.
        self.model_type = "tranception"

        # Architecture selection.
        self.attention_mode = attention_mode
        self.position_embedding = position_embedding
        self.tokenizer = tokenizer

        # Retrieval settings.
        self.retrieval_aggregation_mode = retrieval_aggregation_mode
        self.retrieval_inference_weight = retrieval_inference_weight

        # MSA-related settings.
        self.MSA_filename = MSA_filename
        self.MSA_weight_file_name = MSA_weight_file_name
        self.MSA_start = MSA_start
        self.MSA_end = MSA_end
        self.full_protein_length = full_protein_length
        self.clustal_omega_location = clustal_omega_location

        self.scoring_window = scoring_window