# configuration_protst.py — "Create configuration_protst.py" by Jiqing (commit d9b4e66, verified; 1.42 kB)
from transformers import PretrainedConfig
from transformers.utils import logging
from transformers.models.esm import EsmConfig
# Module-level logger obtained from transformers' logging utilities (used by ProtSTConfig below).
logger = logging.get_logger(__name__)
class ProtSTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ProtSTModel`]. It holds the configuration of
    the protein encoder as a nested [`EsmConfig`] in the `protein_config` attribute.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        protein_config (`dict` or [`EsmConfig`], *optional*):
            Configuration options for the protein encoder (used to initialize [`EsmForProteinRepresentation`]).
            A `dict` is expanded into a new [`EsmConfig`]; an [`EsmConfig`] instance is copied through its
            `to_dict()`. If `None`, an [`EsmConfig`] with default values is created.
        kwargs (*optional*):
            Remaining keyword arguments, forwarded unchanged to [`PretrainedConfig`].
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the `EsmConfig` with default values.")
        elif isinstance(protein_config, EsmConfig):
            # Accept an already-built config as well as a plain dict. Round-tripping through `to_dict()` gives this
            # object its own EsmConfig, mirroring what `from_protein_text_configs` does.
            protein_config = protein_config.to_dict()

        self.protein_config = EsmConfig(**protein_config)

    @classmethod
    def from_protein_text_configs(cls, protein_config: EsmConfig, **kwargs):
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from a protein encoder configuration.

        Only the protein configuration is handled explicitly; any other keyword arguments are forwarded to the
        constructor.

        Args:
            protein_config ([`EsmConfig`]):
                Configuration of the protein encoder. It is serialized with `to_dict()` and re-expanded into a new
                [`EsmConfig`] by the constructor.
            kwargs (*optional*):
                Extra keyword arguments passed to the constructor.

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object.
        """
        return cls(protein_config=protein_config.to_dict(), **kwargs)