File size: 3,028 Bytes
2201cf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import json
from typing import Any, Dict, Optional
from dacite import Config as DaciteConfig
from dacite import from_dict
from omegaconf import OmegaConf
from transformers.configuration_utils import PretrainedConfig
from xlstm import xLSTMLMModelConfig
# from .config_presets import xlstm_cfg_map
class xLSTMConfig(PretrainedConfig):
    """Hugging Face configuration wrapper for xLSTM language models.

    The xLSTM-specific settings are heavily nested, so they are kept
    separate from the regular ``PretrainedConfig`` attributes: they live in
    an OmegaConf container stored on ``self._xlstm_config``. A few commonly
    used values (``vocab_size``, ``embedding_dim``, ``context_length``) are
    additionally mirrored as plain attributes on the instance.
    """

    model_type = "xlstm"

    def __init__(
        self, vocab_size: int = 32000, config: Optional[Dict[str, Any]] = None, **kwargs
    ):
        """Build the configuration.

        Args:
            vocab_size: Vocabulary size; written into the nested xLSTM
                configuration under ``"vocab_size"`` and mirrored on the
                instance.
            config: Nested xLSTM configuration. ``None`` (the default) is
                treated as an empty configuration.
            **kwargs: Forwarded to ``PretrainedConfig.__init__`` and also
                written, key by key, into the nested xLSTM configuration.
        """
        super().__init__(**kwargs)

        # The declared default is ``None``; start from an empty mapping in
        # that case so the key assignments below have a container to write
        # into.
        cfg = OmegaConf.create({} if config is None else config)
        cfg["vocab_size"] = vocab_size
        # NOTE: every extra keyword argument is merged into the xLSTM
        # configuration as well, so it is also part of what
        # ``to_xlstm_config`` hands to dacite.
        for key, value in kwargs.items():
            cfg[key] = value

        self._xlstm_config = cfg
        self.vocab_size = vocab_size
        self.embedding_dim = cfg.get("embedding_dim")
        self.context_length = cfg.get("context_length")

    def to_xlstm_config(self):
        """Return the nested configuration as an ``xLSTMLMModelConfig``.

        The OmegaConf container is converted to plain Python containers and
        passed to dacite's ``from_dict`` with ``strict=True``.
        """
        return from_dict(
            data_class=xLSTMLMModelConfig,
            data=OmegaConf.to_container(self._xlstm_config),
            config=DaciteConfig(strict=True),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert the configuration to a dictionary for serialization.

        Starts from ``PretrainedConfig.to_dict()``, stores the resolved
        nested xLSTM configuration under ``"_xlstm_config"``, and keeps only
        a fixed set of keys; everything else is dropped.
        """
        output = super().to_dict()
        output["_xlstm_config"] = OmegaConf.to_container(
            self._xlstm_config, resolve=True
        )
        relevant_keys = (
            "vocab_size",
            "embedding_dim",
            "context_length",
            "torch_dtype",
            "_xlstm_config",
            "transformers_version",
            "architectures",
            "model_type",
        )
        # Filtering preserves the key order produced by the parent class.
        return {key: value for key, value in output.items() if key in relevant_keys}

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        """Create a configuration instance from a dictionary.

        Args:
            config_dict: Mapping that must contain ``"_xlstm_config"`` and
                ``"vocab_size"``; a truthy ``"auto_map"`` entry is copied
                onto the instance. The mapping is not modified.
            **kwargs: Only ``return_unused_kwargs`` is inspected.

        Returns:
            The configuration, or ``(configuration, {})`` when
            ``return_unused_kwargs`` is truthy.

        Raises:
            KeyError: If ``"_xlstm_config"`` or ``"vocab_size"`` is missing.
        """
        # Read instead of pop so the caller's dictionary is left untouched.
        xlstm_config = config_dict["_xlstm_config"]
        vocab_size = config_dict["vocab_size"]
        config = cls(vocab_size=vocab_size, config=xlstm_config)

        auto_map = config_dict.get("auto_map")
        if auto_map:
            config.auto_map = auto_map

        if kwargs.get("return_unused_kwargs"):
            # No keyword arguments are reported back as unused.
            return config, {}
        return config

    def to_json_string(self, *args, **kwargs) -> str:
        """Serialize the instance to a JSON string.

        Positional and keyword arguments are accepted for signature
        compatibility and ignored; the output is always ``to_dict()``
        rendered with an indent of 2.
        """
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json_string(cls, json_string: str):
        """Deserialize an instance from a JSON string via ``from_dict``."""
        config_dict = json.loads(json_string)
        return cls.from_dict(config_dict)
|