File size: 1,417 Bytes
d9b4e66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from transformers import PretrainedConfig
from transformers.utils import logging
from transformers.models.esm import EsmConfig
# Module-level logger obtained from transformers' logging utilities; used by ProtSTConfig below.
logger = logging.get_logger(__name__)
class ProtSTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ProtSTModel`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        protein_config (`dict` or [`EsmConfig`], *optional*):
            Dictionary of configuration options (or an already-built [`EsmConfig`]) used to initialize
            [`EsmForProteinRepresentation`]. When omitted, an [`EsmConfig`] with default values is used.
        kwargs (*optional*):
            Remaining keyword arguments, forwarded unchanged to [`PretrainedConfig.__init__`].
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the `EsmConfig` with default values.")
        elif isinstance(protein_config, EsmConfig):
            # Accept an already-built config as well as a plain dict. Converting it with `to_dict()` mirrors
            # `from_protein_text_configs`, so both entry points build the stored EsmConfig the same way.
            protein_config = protein_config.to_dict()

        # A new EsmConfig is always constructed from the (dict) options.
        self.protein_config = EsmConfig(**protein_config)

    @classmethod
    def from_protein_text_configs(cls, protein_config: EsmConfig, **kwargs):
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from a protein model configuration.

        Args:
            protein_config ([`EsmConfig`]):
                Configuration of the protein model. It is serialized with `to_dict()` and rebuilt by the
                constructor.
            kwargs (*optional*):
                Extra keyword arguments forwarded to the constructor.

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object.
        """
        return cls(protein_config=protein_config.to_dict(), **kwargs)