|
import warnings |
|
|
|
from transformers import PretrainedConfig |
|
from transformers import CONFIG_MAPPING |
|
|
|
from .configuration_moment import MomentConfig |
|
|
|
class MistsConfig(PretrainedConfig): |
|
model_type = "mists" |
|
|
|
def __init__( |
|
self, |
|
time_series_config=None, |
|
text_config=None, |
|
ignore_index=-100, |
|
time_series_token_index=32000, |
|
projector_hidden_act="gelu", |
|
|
|
|
|
time_series_hidden_size=1024, |
|
**kwargs, |
|
): |
|
|
|
self.ignore_index = ignore_index |
|
self.time_series_token_index = time_series_token_index |
|
self.projector_hidden_act = projector_hidden_act |
|
self.time_series_hidden_size = time_series_hidden_size |
|
|
|
|
|
|
|
if isinstance(time_series_config, dict): |
|
time_series_config["model_type"] = ( |
|
time_series_config["model_type"] if "model_type" in time_series_config else "moment" |
|
) |
|
|
|
time_series_config = MomentConfig(**time_series_config) |
|
elif time_series_config is None: |
|
time_series_config = MomentConfig() |
|
|
|
self.time_series_config = time_series_config |
|
|
|
if isinstance(text_config, dict): |
|
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "mistral" |
|
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) |
|
elif text_config is None: |
|
text_config = CONFIG_MAPPING["mistral"]() |
|
|
|
self.text_config = text_config |
|
|
|
super().__init__(**kwargs) |
|
|
|
|
|
def to_dict(self): |
|
output = super().to_dict() |
|
return output |
|
|
|
|
|
|
|
|
|
|
|
|