from transformers import PretrainedConfig
from transformers.utils import logging
from transformers.models.esm import EsmConfig

logger = logging.get_logger(__name__)


class ProtSTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ProtSTModel`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        protein_config (`dict` or [`EsmConfig`], *optional*):
            Configuration options for the ESM protein encoder, stored as an [`EsmConfig`] in `self.protein_config`.
            A dictionary is expanded into a new [`EsmConfig`]. An [`EsmConfig`] instance is converted with
            `to_dict()` and rebuilt into a new [`EsmConfig`]. When `None`, an [`EsmConfig`] with default values is
            created.
        kwargs (*optional*):
            Remaining keyword arguments, forwarded unchanged to [`PretrainedConfig`].
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the `EsmConfig` with default values.")
        elif isinstance(protein_config, EsmConfig):
            # Accept a ready-made EsmConfig as well as a dict. Previously this raised
            # TypeError because `**` requires a mapping. Going through `to_dict()`
            # mirrors `from_protein_text_configs` and builds a new EsmConfig object.
            protein_config = protein_config.to_dict()

        self.protein_config = EsmConfig(**protein_config)

    @classmethod
    def from_protein_text_configs(
        cls, protein_config: EsmConfig, **kwargs
    ):
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from an ESM protein model configuration.

        Args:
            protein_config ([`EsmConfig`]):
                Configuration of the protein encoder; it is serialized with `to_dict()` before being passed to the
                constructor.
            kwargs (*optional*):
                Additional keyword arguments forwarded to the constructor.

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object
        """
        return cls(protein_config=protein_config.to_dict(), **kwargs)