import json
from typing import Any, Dict, Optional

from dacite import Config as DaciteConfig
from dacite import from_dict
from omegaconf import OmegaConf
from transformers.configuration_utils import PretrainedConfig
from xlstm import xLSTMLMModelConfig

# from .config_presets import xlstm_cfg_map


class xLSTMConfig(PretrainedConfig):
    """xLSTM configuration class for Hugging Face ``transformers``.

    The xLSTM-specific model configuration is kept separate from the rest
    because it is heavily nested; it is stored as an OmegaConf object in
    ``_xlstm_config``.
    """

    model_type = "xlstm"

    def __init__(
        self, vocab_size: int = 32000, config: Optional[Dict[str, Any]] = None, **kwargs
    ):
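        super().__init__(**kwargs)

        # Build the nested xLSTM config, falling back to an empty mapping when
        # no config is given so that the keys below can always be set.
        cfg = OmegaConf.create(config if config is not None else {})
        cfg["vocab_size"] = vocab_size
        # Extra keyword arguments are also copied into the nested config,
        # overriding any matching entries.
        for key, value in kwargs.items():
            cfg[key] = value

        self._xlstm_config = cfg
        # Mirror frequently used fields as top-level attributes.
        self.vocab_size = vocab_size
        self.embedding_dim = cfg.get("embedding_dim")
        self.context_length = cfg.get("context_length")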

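    def to_xlstm_config(self) -> xLSTMLMModelConfig:
        """
        Converts the nested configuration into an ``xLSTMLMModelConfig``.

        ``dacite`` runs in strict mode, so keys that are not fields of the
        dataclass raise an error.
        """
        return from_dict(
            data_class=xLSTMLMModelConfig,
            data=OmegaConf.to_container(self._xlstm_config),
            config=DaciteConfig(strict=True),
        )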

    def to_dict(self) -> Dict[str, Any]:
        """
        Converts the configuration to a dictionary for serialization, keeping
        only the keys needed to rebuild it.
        """
        output = super().to_dict()
        output["_xlstm_config"] = OmegaConf.to_container(
            self._xlstm_config, resolve=True
        )
        relevant_keys = {
            "vocab_size",
            "embedding_dim",
            "context_length",
            "torch_dtype",
            "_xlstm_config",
            "transformers_version",
            "architectures",
            "model_type",
        }
        # Drop every key that is not needed to reconstruct the configuration.
        return {key: value for key, value in output.items() if key in relevant_keys}

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        """
        Creates a configuration instance from a dictionary produced by ``to_dict``.
        """
        # Copy so that popping keys does not mutate the caller's dictionary.
        config_dict = dict(config_dict)
        xlstm_config = config_dict.pop("_xlstm_config")
        vocab_size = config_dict.pop("vocab_size")
        config = cls(vocab_size=vocab_size, config=xlstm_config)
        if config_dict.get("auto_map"):
            config.auto_map = config_dict.pop("auto_map")

        # Extra keyword arguments are ignored; an empty dict of unused kwargs is
        # returned when the caller asks for it, as PretrainedConfig.from_dict does.
        if kwargs.get("return_unused_kwargs", False):
            return config, {}

        return config

    def to_json_string(self, *args, **kwargs) -> str:
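        """
        Serializes the instance to a JSON string.

        Extra arguments (such as ``use_diff`` in ``PretrainedConfig``) are
        accepted for compatibility but ignored; the filtered ``to_dict``
        output is always written.
        """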
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json_string(cls, json_string: str):
        """
        Deserializes the instance from a JSON string.
        """
        config_dict = json.loads(json_string)
        return cls.from_dict(config_dict)