File size: 1,417 Bytes
d9b4e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from transformers import PretrainedConfig
from transformers.utils import logging
from transformers.models.esm import EsmConfig

logger = logging.get_logger(__name__)


class ProtSTConfig(PretrainedConfig):
    r"""
    Configuration class to store the configuration of a [`ProtSTModel`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        protein_config (`dict` or [`EsmConfig`], *optional*):
            Configuration options for the protein encoder. A `dict` is passed as keyword arguments to
            [`EsmConfig`]. An [`EsmConfig`] instance is first converted with `to_dict()`, so the stored config is
            a separate object from the one the caller passed in. If `None`, a default [`EsmConfig`] is created.
        kwargs (*optional*):
            Remaining keyword arguments, forwarded unchanged to [`PretrainedConfig`].
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the `EsmConfig` with default values.")
        elif isinstance(protein_config, EsmConfig):
            # Accept a config object directly; `EsmConfig(**obj)` would raise TypeError because `**`
            # needs a mapping. Round-tripping through a dict (as `from_protein_text_configs` does) also
            # keeps later mutations of the caller's object from leaking into this config.
            protein_config = protein_config.to_dict()

        # Always stored as an `EsmConfig` object, never as the raw dict.
        self.protein_config = EsmConfig(**protein_config)

    @classmethod
    def from_protein_text_configs(cls, protein_config: EsmConfig, **kwargs):
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from a protein encoder configuration.

        Args:
            protein_config ([`EsmConfig`]):
                Configuration of the protein encoder; serialized with `to_dict()` before being passed on.
            kwargs (*optional*):
                Additional keyword arguments forwarded to the class constructor.

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object.
        """
        return cls(protein_config=protein_config.to_dict(), **kwargs)