File size: 2,295 Bytes
13df84c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import warnings

from transformers import PretrainedConfig
from transformers import CONFIG_MAPPING

from .configuration_moment import MomentConfig

class MistsConfig(PretrainedConfig):
    model_type = "mists"

    def __init__(
        self, 
        time_series_config=None,
        text_config=None,
        ignore_index=-100,
        time_series_token_index=32000,
        projector_hidden_act="gelu",  # projector用
        # time_series_feature_select_strategy="default",  # TODO: modelのforward用(画像モデルのhidden_stateからEmbeddingをどう取得するか)。将来的に対応。
        # time_series_feature_layer=-2,  # modelのforward用  # TODO: modelのforward用(画像モデルのhidden_stateからEmbeddingをどう取得するか)。将来的に対応。
        time_series_hidden_size=1024,  # projector用
        **kwargs,
    ):
        
        self.ignore_index = ignore_index
        self.time_series_token_index = time_series_token_index
        self.projector_hidden_act = projector_hidden_act
        self.time_series_hidden_size = time_series_hidden_size

        # 将来的に、MomentモデルがTransformersに登録されることを想定して追加する
        # そのため、CONFIG_MAPPINGは機能しない。
        if isinstance(time_series_config, dict):
            time_series_config["model_type"] = (
                time_series_config["model_type"] if "model_type" in time_series_config else "moment"
            )
            # time_series_config = CONFIG_MAPPING[time_series_config["model_type"]](**time_series_config)
            time_series_config = MomentConfig(**time_series_config)
        elif time_series_config is None:
            time_series_config = MomentConfig()

        self.time_series_config = time_series_config

        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "mistral"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["mistral"]()

        self.text_config = text_config

        super().__init__(**kwargs)


    def to_dict(self):
        output = super().to_dict()
        return output